diff --git a/.gitignore b/.gitignore index 85baa82ae0b0..bb32477feda6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ # Temporary output of build tools bazel-* *.out + +# Repomix outputs +repomix*.xml \ No newline at end of file diff --git a/go.mod b/go.mod index e9e0cb8f0b84..2cc07de02964 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( golang.org/x/sys v0.36.0 golang.org/x/time v0.10.0 google.golang.org/api v0.198.0 - google.golang.org/grpc v1.74.2 + google.golang.org/grpc v1.75.0 gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.33.4 k8s.io/apiextensions-apiserver v0.33.4 @@ -153,8 +153,8 @@ require ( golang.org/x/text v0.29.0 // indirect golang.org/x/tools v0.37.0 // indirect gomodules.xyz/jsonpatch/v2 v2.5.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c // indirect google.golang.org/protobuf v1.36.8 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index 2f3c1dc5e82a..0a828023bb57 100644 --- a/go.sum +++ b/go.sum @@ -507,8 +507,9 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gomodules.xyz/jsonpatch/v2 v2.5.0 h1:JELs8RLM12qJGXU4u/TO3V25KW8GreMKl9pdkk14RM0= gomodules.xyz/jsonpatch/v2 v2.5.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= -gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca h1:PupagGYwj8+I4ubCxcmcBRk3VlUWtTg5huQpZR9flmE= gonum.org/v1/gonum v0.0.0-20181121035319-3f7ecaa7e8ca/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/netlib v0.0.0-20181029234149-ec6d1f5cefe6/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= google.golang.org/api v0.198.0 h1:OOH5fZatk57iN0A7tjJQzt6aPfYQ1JiWkt1yGseazks= google.golang.org/api v0.198.0/go.mod h1:/Lblzl3/Xqqk9hw/yS97TImKTUwnf1bv89v7+OagJzc= @@ -517,17 +518,17 @@ google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= -google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= +google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 h1:FiusG7LWj+4byqhbvmB+Q93B/mOxJLN2DTozDuZm4EU= +google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7/go.mod 
h1:kXqgZtrWaf6qS3jZOCnCH7WYfrvFjkC51bM8fz3RsCA= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c h1:qXWI/sQtv5UKboZ/zUk7h+mrf/lXORyI+n9DKDAusdg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c/go.mod h1:gw1tLEfykwDz2ET4a12jcXt4couGAm7IwsVaTy0Sflo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.74.2 h1:WoosgB65DlWVC9FqI82dGsZhWFNBSLjQ84bjROOpMu4= -google.golang.org/grpc v1.74.2/go.mod h1:CtQ+BGjaAIXHs/5YS3i473GqwBBa1zGQNevxdeBEXrM= +google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4= +google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/pkg/activator/net/throttler.go b/pkg/activator/net/throttler.go index 0ef298c48db4..daf768ee0ce3 100644 --- a/pkg/activator/net/throttler.go +++ b/pkg/activator/net/throttler.go @@ -100,7 +100,13 @@ func (p *podTracker) Capacity() int { if p.b == nil { return 1 } - return p.b.Capacity() + capacity := p.b.Capacity() + // Safe conversion: breaker capacity is always reasonable for int + // Check for overflow before conversion + if capacity > 0x7FFFFFFF { + return 0x7FFFFFFF // Return max int32 value + } + return int(capacity) } func (p *podTracker) UpdateConcurrency(c int) { @@ -118,7 +124,7 @@ func (p *podTracker) Reserve(ctx context.Context) (func(), bool) { } type breaker interface { - Capacity() int + Capacity() uint64 Maybe(ctx context.Context, thunk func()) error UpdateConcurrency(int) Reserve(ctx context.Context) (func(), bool) @@ -721,8 +727,13 @@ func newInfiniteBreaker(logger *zap.SugaredLogger) *infiniteBreaker { } // Capacity returns the current capacity of the breaker -func (ib *infiniteBreaker) Capacity() int { - return int(ib.concurrency.Load()) +func (ib *infiniteBreaker) Capacity() uint64 { + concurrency := ib.concurrency.Load() + // Safe conversion: concurrency is int32 and we check for non-negative + if concurrency >= 0 { + return uint64(concurrency) + } + return 0 } func zeroOrOne(x int) int32 { diff --git a/pkg/activator/net/throttler_test.go b/pkg/activator/net/throttler_test.go index ed727a26c79b..bef964fdef61 100644 --- a/pkg/activator/net/throttler_test.go +++ b/pkg/activator/net/throttler_test.go @@ -226,7 +226,7 @@ func TestThrottlerUpdateCapacity(t *testing.T) { rt.breaker = newInfiniteBreaker(logger) } rt.updateCapacity(tt.capacity) - if got := rt.breaker.Capacity(); got != tt.want { + if got := rt.breaker.Capacity(); got != uint64(tt.want) { t.Errorf("Capacity = %d, want: %d", got, tt.want) } if tt.checkAssignedPod { @@ -560,7 +560,7 @@ func TestThrottlerSuccesses(t *testing.T) { rt.mux.RLock() defer rt.mux.RUnlock() if *cc != 0 { - return rt.activatorIndex.Load() != -1 && rt.breaker.Capacity() == wantCapacity && + return rt.activatorIndex.Load() != -1 && 
rt.breaker.Capacity() == uint64(wantCapacity) && sortedTrackers(rt.assignedTrackers), nil } // If CC=0 then verify number of backends, rather the capacity of breaker. @@ -638,7 +638,7 @@ func TestPodAssignmentFinite(t *testing.T) { if got, want := trackerDestSet(rt.assignedTrackers), sets.New("ip0", "ip4"); !got.Equal(want) { t.Errorf("Assigned trackers = %v, want: %v, diff: %s", got, want, cmp.Diff(want, got)) } - if got, want := rt.breaker.Capacity(), 2*42; got != want { + if got, want := rt.breaker.Capacity(), uint64(2*42); got != want { t.Errorf("TotalCapacity = %d, want: %d", got, want) } if got, want := rt.assignedTrackers[0].Capacity(), 42; got != want { @@ -657,7 +657,7 @@ func TestPodAssignmentFinite(t *testing.T) { if got, want := len(rt.assignedTrackers), 0; got != want { t.Errorf("NumAssignedTrackers = %d, want: %d", got, want) } - if got, want := rt.breaker.Capacity(), 0; got != want { + if got, want := rt.breaker.Capacity(), uint64(0); got != want { t.Errorf("TotalCapacity = %d, want: %d", got, want) } } @@ -687,7 +687,7 @@ func TestPodAssignmentInfinite(t *testing.T) { if got, want := len(rt.assignedTrackers), 3; got != want { t.Errorf("NumAssigned trackers = %d, want: %d", got, want) } - if got, want := rt.breaker.Capacity(), 1; got != want { + if got, want := rt.breaker.Capacity(), uint64(1); got != want { t.Errorf("TotalCapacity = %d, want: %d", got, want) } if got, want := rt.assignedTrackers[0].Capacity(), 1; got != want { @@ -703,7 +703,7 @@ func TestPodAssignmentInfinite(t *testing.T) { if got, want := len(rt.assignedTrackers), 0; got != want { t.Errorf("NumAssignedTrackers = %d, want: %d", got, want) } - if got, want := rt.breaker.Capacity(), 0; got != want { + if got, want := rt.breaker.Capacity(), uint64(0); got != want { t.Errorf("TotalCapacity = %d, want: %d", got, want) } } @@ -935,7 +935,7 @@ func TestInfiniteBreaker(t *testing.T) { } // Verify initial condition. - if got, want := b.Capacity(), 0; got != want { + if got, want := b.Capacity(), uint64(0); got != want { t.Errorf("Cap=%d, want: %d", got, want) } if _, ok := b.Reserve(context.Background()); ok != true { @@ -949,7 +949,7 @@ func TestInfiniteBreaker(t *testing.T) { } b.UpdateConcurrency(1) - if got, want := b.Capacity(), 1; got != want { + if got, want := b.Capacity(), uint64(1); got != want { t.Errorf("Cap=%d, want: %d", got, want) } @@ -976,7 +976,7 @@ func TestInfiniteBreaker(t *testing.T) { if err := b.Maybe(ctx, nil); err == nil { t.Error("Should have failed, but didn't") } - if got, want := b.Capacity(), 0; got != want { + if got, want := b.Capacity(), uint64(0); got != want { t.Errorf("Cap=%d, want: %d", got, want) } @@ -1212,7 +1212,7 @@ func TestAssignSlice(t *testing.T) { t.Errorf("Got=%v, want: %v; diff: %s", got, want, cmp.Diff(want, got, opt)) } - if got, want := got[0].b.Capacity(), 0; got != want { + if got, want := got[0].b.Capacity(), uint64(0); got != want { t.Errorf("Capacity for the tail pod = %d, want: %d", got, want) } }) diff --git a/pkg/queue/breaker.go b/pkg/queue/breaker.go index 918f57b743a5..4c774718f419 100644 --- a/pkg/queue/breaker.go +++ b/pkg/queue/breaker.go @@ -43,7 +43,7 @@ type BreakerParams struct { // executions in excess of the concurrency limit. Function call attempts // beyond the limit of the queue are failed immediately. 
type Breaker struct { - inFlight atomic.Int64 + pending atomic.Int64 totalSlots int64 sem *semaphore @@ -83,10 +83,10 @@ func NewBreaker(params BreakerParams) *Breaker { func (b *Breaker) tryAcquirePending() bool { // This is an atomic version of: // - // if inFlight == totalSlots { + // if pending == totalSlots { // return false // } else { - // inFlight++ + // pending++ // return true // } // @@ -96,11 +96,12 @@ func (b *Breaker) tryAcquirePending() bool { // (it fails if we're raced to it) or if we don't fulfill the condition // anymore. for { - cur := b.inFlight.Load() + cur := b.pending.Load() + // 10000 + containerConcurrency = totalSlots if cur == b.totalSlots { return false } - if b.inFlight.CompareAndSwap(cur, cur+1) { + if b.pending.CompareAndSwap(cur, cur+1) { return true } } @@ -108,7 +109,7 @@ func (b *Breaker) tryAcquirePending() bool { // releasePending releases a slot on the pending "queue". func (b *Breaker) releasePending() { - b.inFlight.Add(-1) + b.pending.Add(-1) } // Reserve reserves an execution slot in the breaker, to permit @@ -154,9 +155,9 @@ func (b *Breaker) Maybe(ctx context.Context, thunk func()) error { return nil } -// InFlight returns the number of requests currently in flight in this breaker. -func (b *Breaker) InFlight() int { - return int(b.inFlight.Load()) +// Pending returns the number of requests currently pending to this breaker. +func (b *Breaker) Pending() int { + return int(b.pending.Load()) } // UpdateConcurrency updates the maximum number of in-flight requests. @@ -165,10 +166,15 @@ func (b *Breaker) UpdateConcurrency(size int) { } // Capacity returns the number of allowed in-flight requests on this breaker. -func (b *Breaker) Capacity() int { +func (b *Breaker) Capacity() uint64 { return b.sem.Capacity() } +// InFlight returns the number of requests currently in-flight on this breaker. +func (b *Breaker) InFlight() uint64 { + return b.sem.InFlight() +} + // newSemaphore creates a semaphore with the desired initial capacity. func newSemaphore(maxCapacity, initialCapacity int) *semaphore { queue := make(chan struct{}, maxCapacity) @@ -288,9 +294,15 @@ func (s *semaphore) updateCapacity(size int) { } // Capacity is the capacity of the semaphore. -func (s *semaphore) Capacity() int { +func (s *semaphore) Capacity() uint64 { capacity, _ := unpack(s.state.Load()) - return int(capacity) //nolint:gosec // TODO(dprotaso) - capacity should be uint64 + return capacity +} + +// InFlight is the number of the inflight requests of the semaphore. 
+func (s *semaphore) InFlight() uint64 { + _, inFlight := unpack(s.state.Load()) + return inFlight } // unpack takes an uint64 and returns two uint32 (as uint64) comprised of the leftmost diff --git a/pkg/queue/breaker_test.go b/pkg/queue/breaker_test.go index 547959a1da54..c7e838f82bdc 100644 --- a/pkg/queue/breaker_test.go +++ b/pkg/queue/breaker_test.go @@ -212,12 +212,12 @@ func TestBreakerUpdateConcurrency(t *testing.T) { params := BreakerParams{QueueDepth: 1, MaxConcurrency: 1, InitialCapacity: 0} b := NewBreaker(params) b.UpdateConcurrency(1) - if got, want := b.Capacity(), 1; got != want { + if got, want := b.Capacity(), uint64(1); got != want { t.Errorf("Capacity() = %d, want: %d", got, want) } b.UpdateConcurrency(0) - if got, want := b.Capacity(), 0; got != want { + if got, want := b.Capacity(), uint64(0); got != want { t.Errorf("Capacity() = %d, want: %d", got, want) } } @@ -294,12 +294,12 @@ func TestSemaphoreRelease(t *testing.T) { func TestSemaphoreUpdateCapacity(t *testing.T) { const initialCapacity = 1 sem := newSemaphore(3, initialCapacity) - if got, want := sem.Capacity(), 1; got != want { + if got, want := sem.Capacity(), uint64(1); got != want { t.Errorf("Capacity = %d, want: %d", got, want) } sem.acquire(context.Background()) sem.updateCapacity(initialCapacity + 2) - if got, want := sem.Capacity(), 3; got != want { + if got, want := sem.Capacity(), uint64(3); got != want { t.Errorf("Capacity = %d, want: %d", got, want) } } diff --git a/pkg/queue/drain/signals.go b/pkg/queue/drain/signals.go new file mode 100644 index 000000000000..cdcc78708484 --- /dev/null +++ b/pkg/queue/drain/signals.go @@ -0,0 +1,56 @@ +/* +Copyright 2024 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package drain + +const ( + // SignalDirectory is the directory where drain signal files are created + SignalDirectory = "/var/run/knative" + + // DrainStartedFile indicates that pod termination has begun and queue-proxy is handling shutdown + DrainStartedFile = SignalDirectory + "/drain-started" + + // DrainCompleteFile indicates that queue-proxy has finished draining requests + DrainCompleteFile = SignalDirectory + "/drain-complete" + + // DrainCheckInterval is how often to check for drain completion + DrainCheckInterval = "1" // seconds + + // ExponentialBackoffDelays are the delays in seconds for checking drain-started file + // Total max wait time: 1+2+4+8 = 15 seconds + ExponentialBackoffDelays = "1 2 4 8" +) + +// BuildDrainWaitScript generates the shell script for waiting on drain signals. +// If existingCommand is provided, it will be executed before the drain wait. 
+func BuildDrainWaitScript(existingCommand string) string { + drainLogic := `for i in ` + ExponentialBackoffDelays + `; do ` + + ` if [ -f ` + DrainStartedFile + ` ]; then ` + + ` until [ -f ` + DrainCompleteFile + ` ]; do sleep ` + DrainCheckInterval + `; done; ` + + ` exit 0; ` + + ` fi; ` + + ` sleep $i; ` + + `done; ` + + `exit 0` + + if existingCommand != "" { + return existingCommand + "; " + drainLogic + } + return drainLogic +} + +// QueueProxyPreStopScript is the script executed by queue-proxy's PreStop hook +const QueueProxyPreStopScript = "touch " + DrainStartedFile diff --git a/pkg/queue/drain/signals_test.go b/pkg/queue/drain/signals_test.go new file mode 100644 index 000000000000..7e4bc52cf79a --- /dev/null +++ b/pkg/queue/drain/signals_test.go @@ -0,0 +1,242 @@ +/* +Copyright 2024 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package drain + +import ( + "strings" + "testing" +) + +func TestConstants(t *testing.T) { + tests := []struct { + name string + got string + expected string + }{ + { + name: "SignalDirectory", + got: SignalDirectory, + expected: "/var/run/knative", + }, + { + name: "DrainStartedFile", + got: DrainStartedFile, + expected: "/var/run/knative/drain-started", + }, + { + name: "DrainCompleteFile", + got: DrainCompleteFile, + expected: "/var/run/knative/drain-complete", + }, + { + name: "DrainCheckInterval", + got: DrainCheckInterval, + expected: "1", + }, + { + name: "ExponentialBackoffDelays", + got: ExponentialBackoffDelays, + expected: "1 2 4 8", + }, + { + name: "QueueProxyPreStopScript", + got: QueueProxyPreStopScript, + expected: "touch /var/run/knative/drain-started", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.got != tt.expected { + t.Errorf("got %q, want %q", tt.got, tt.expected) + } + }) + } +} + +func TestBuildDrainWaitScript(t *testing.T) { + tests := []struct { + name string + existingCommand string + wantContains []string + wantExact bool + }{ + { + name: "without existing command", + existingCommand: "", + wantContains: []string{ + "for i in 1 2 4 8", + "if [ -f /var/run/knative/drain-started ]", + "until [ -f /var/run/knative/drain-complete ]", + "sleep 1", + "sleep $i", + "exit 0", + }, + wantExact: false, + }, + { + name: "with existing command", + existingCommand: "echo 'custom prestop'", + wantContains: []string{ + "echo 'custom prestop'", + "for i in 1 2 4 8", + "if [ -f /var/run/knative/drain-started ]", + "until [ -f /var/run/knative/drain-complete ]", + "sleep 1", + "sleep $i", + "exit 0", + }, + wantExact: false, + }, + { + name: "with complex existing command", + existingCommand: "/bin/sh -c 'kill -TERM 1 && wait'", + wantContains: []string{ + "/bin/sh -c 'kill -TERM 1 && wait'", + "for i in 1 2 4 8", + "if [ -f /var/run/knative/drain-started ]", + "until [ -f /var/run/knative/drain-complete ]", + }, + wantExact: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := BuildDrainWaitScript(tt.existingCommand) + + for _, want := range tt.wantContains 
{ + if !strings.Contains(got, want) { + t.Errorf("BuildDrainWaitScript() missing expected substring %q\nGot: %q", want, got) + } + } + + // Verify the command structure + if tt.existingCommand != "" { + // Should start with the existing command + if !strings.HasPrefix(got, tt.existingCommand+"; ") { + t.Errorf("BuildDrainWaitScript() should start with existing command followed by '; '\nGot: %q", got) + } + } + + // Verify the script ends with exit 0 + if !strings.HasSuffix(got, "exit 0") { + t.Errorf("BuildDrainWaitScript() should end with 'exit 0'\nGot: %q", got) + } + }) + } +} + +func TestBuildDrainWaitScriptStructure(t *testing.T) { + // Test the exact structure of the generated script without existing command + got := BuildDrainWaitScript("") + expected := "for i in 1 2 4 8; do " + + " if [ -f /var/run/knative/drain-started ]; then " + + " until [ -f /var/run/knative/drain-complete ]; do sleep 1; done; " + + " exit 0; " + + " fi; " + + " sleep $i; " + + "done; " + + "exit 0" + + if got != expected { + t.Errorf("BuildDrainWaitScript(\"\") structure mismatch\nGot: %q\nExpected: %q", got, expected) + } +} + +func TestBuildDrainWaitScriptWithCommandStructure(t *testing.T) { + // Test the exact structure of the generated script with existing command + existingCmd := "echo 'test'" + got := BuildDrainWaitScript(existingCmd) + expected := "echo 'test'; for i in 1 2 4 8; do " + + " if [ -f /var/run/knative/drain-started ]; then " + + " until [ -f /var/run/knative/drain-complete ]; do sleep 1; done; " + + " exit 0; " + + " fi; " + + " sleep $i; " + + "done; " + + "exit 0" + + if got != expected { + t.Errorf("BuildDrainWaitScript with command structure mismatch\nGot: %q\nExpected: %q", got, expected) + } +} + +func TestBuildDrainWaitScriptEdgeCases(t *testing.T) { + tests := []struct { + name string + existingCommand string + checkFunc func(t *testing.T, result string) + }{ + { + name: "empty string produces valid script", + existingCommand: "", + checkFunc: func(t *testing.T, result string) { + if result == "" { + t.Error("BuildDrainWaitScript(\"\") should not return empty string") + } + if !strings.Contains(result, "for i in") { + t.Error("BuildDrainWaitScript(\"\") should contain for loop") + } + }, + }, + { + name: "command with semicolon", + existingCommand: "cmd1; cmd2", + checkFunc: func(t *testing.T, result string) { + if !strings.HasPrefix(result, "cmd1; cmd2; ") { + t.Error("BuildDrainWaitScript should preserve command with semicolons") + } + }, + }, + { + name: "command with special characters", + existingCommand: "echo '$VAR' && test -f /tmp/file", + checkFunc: func(t *testing.T, result string) { + if !strings.HasPrefix(result, "echo '$VAR' && test -f /tmp/file; ") { + t.Error("BuildDrainWaitScript should preserve special characters") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BuildDrainWaitScript(tt.existingCommand) + tt.checkFunc(t, result) + }) + } +} + +func BenchmarkBuildDrainWaitScript(b *testing.B) { + testCases := []struct { + name string + command string + }{ + {"NoCommand", ""}, + {"SimpleCommand", "echo test"}, + {"ComplexCommand", "/bin/sh -c 'kill -TERM 1 && wait'"}, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + for range b.N { + _ = BuildDrainWaitScript(tc.command) + } + }) + } +} diff --git a/pkg/queue/request_metric.go b/pkg/queue/request_metric.go index a1406d2c41ce..50c4f2063b2c 100644 --- a/pkg/queue/request_metric.go +++ b/pkg/queue/request_metric.go @@ -85,7 +85,7 @@ func (h 
*appRequestMetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Requ startTime := h.clock.Now() if h.breaker != nil { - h.queueLen.Record(r.Context(), int64(h.breaker.InFlight())) + h.queueLen.Record(r.Context(), int64(h.breaker.Pending())) } defer func() { // Filter probe requests for revision metrics. diff --git a/pkg/queue/sharedmain/handlers.go b/pkg/queue/sharedmain/handlers.go index 86a1694def04..cf4162fa03a4 100644 --- a/pkg/queue/sharedmain/handlers.go +++ b/pkg/queue/sharedmain/handlers.go @@ -18,8 +18,11 @@ package sharedmain import ( "context" + "fmt" "net" "net/http" + "strings" + "sync/atomic" "time" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -46,6 +49,7 @@ func mainHandler( logger *zap.SugaredLogger, mp metric.MeterProvider, tp trace.TracerProvider, + pendingRequests *atomic.Int32, ) (http.Handler, *pkghandler.Drainer) { target := net.JoinHostPort("127.0.0.1", env.UserPort) tracer := tp.Tracer("knative.dev/serving/pkg/queue") @@ -73,6 +77,7 @@ func mainHandler( composedHandler = requestAppMetricsHandler(logger, composedHandler, breaker, mp) composedHandler = queue.ProxyHandler(tracer, breaker, stats, composedHandler) + composedHandler = queue.ForwardedShimHandler(composedHandler) composedHandler = handler.NewTimeoutHandler(composedHandler, "request timeout", func(r *http.Request) (time.Duration, time.Duration, time.Duration) { return timeout, responseStartTimeout, idleTimeout @@ -90,6 +95,8 @@ func mainHandler( } composedHandler = drainer + composedHandler = withRequestCounter(composedHandler, pendingRequests) + if env.Observability.EnableRequestLog { // We want to capture the probes/healthchecks in the request logs. // Hence we need to have RequestLogHandler be the first one. @@ -105,11 +112,10 @@ func mainHandler( return !netheader.IsProbe(r) }), ) - return composedHandler, drainer } -func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkghandler.Drainer) http.Handler { +func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkghandler.Drainer, pendingRequests *atomic.Int32) http.Handler { mux := http.NewServeMux() mux.HandleFunc(queue.RequestQueueDrainPath, func(w http.ResponseWriter, r *http.Request) { logger.Info("Attached drain handler from user-container", r) @@ -130,6 +136,17 @@ func adminHandler(ctx context.Context, logger *zap.SugaredLogger, drainer *pkgha w.WriteHeader(http.StatusOK) }) + // New endpoint that returns 200 only when all requests are drained + mux.HandleFunc("/drain-complete", func(w http.ResponseWriter, r *http.Request) { + if pendingRequests.Load() <= 0 { + w.WriteHeader(http.StatusOK) + w.Write([]byte("drained")) + } else { + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, "pending requests: %d", pendingRequests.Load()) + } + }) + return mux } @@ -145,3 +162,29 @@ func withFullDuplex(h http.Handler, enableFullDuplex bool, logger *zap.SugaredLo h.ServeHTTP(w, r) }) } + +func isProbeRequest(r *http.Request) bool { + // Check standard probes (K8s and Knative probe headers) + if netheader.IsProbe(r) { + return true + } + + // Check all Knative internal probe user agents that should not be counted + // as pending requests (matching what the Drainer filters) + userAgent := r.Header.Get("User-Agent") + return strings.HasPrefix(userAgent, netheader.ActivatorUserAgent) || + strings.HasPrefix(userAgent, netheader.AutoscalingUserAgent) || + strings.HasPrefix(userAgent, netheader.QueueProxyUserAgent) || + strings.HasPrefix(userAgent, netheader.IngressReadinessUserAgent) +} + 
+func withRequestCounter(h http.Handler, pendingRequests *atomic.Int32) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only count non-probe requests as pending + if !isProbeRequest(r) { + pendingRequests.Add(1) + defer pendingRequests.Add(-1) + } + h.ServeHTTP(w, r) + }) +} diff --git a/pkg/queue/sharedmain/handlers_integration_test.go b/pkg/queue/sharedmain/handlers_integration_test.go new file mode 100644 index 000000000000..aba6a5cadd32 --- /dev/null +++ b/pkg/queue/sharedmain/handlers_integration_test.go @@ -0,0 +1,259 @@ +/* +Copyright 2024 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sharedmain + +import ( + "net" + "net/http" + "net/http/httptest" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/trace" + "go.uber.org/zap" + + netstats "knative.dev/networking/pkg/http/stats" + "knative.dev/pkg/network" + "knative.dev/serving/pkg/observability" +) + +func TestMainHandlerWithPendingRequests(t *testing.T) { + logger := zap.NewNop().Sugar() + tp := trace.NewTracerProvider() + mp := metric.NewMeterProvider() + + // Create a backend server to proxy to + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Simulate some processing time + time.Sleep(50 * time.Millisecond) + w.WriteHeader(http.StatusOK) + w.Write([]byte("backend response")) + })) + defer backend.Close() + + // Extract port from backend URL + _, port, _ := net.SplitHostPort(backend.Listener.Addr().String()) + + env := config{ + ContainerConcurrency: 10, + QueueServingPort: "8080", + QueueServingTLSPort: "8443", + UserPort: port, + RevisionTimeoutSeconds: 300, + ServingLoggingConfig: "", + ServingLoggingLevel: "info", + Observability: *observability.DefaultConfig(), + Env: Env{ + ServingNamespace: "test-namespace", + ServingConfiguration: "test-config", + ServingRevision: "test-revision", + ServingPod: "test-pod", + ServingPodIP: "10.0.0.1", + }, + } + + transport := buildTransport(env, tp, mp) + prober := func() bool { return true } + stats := netstats.NewRequestStats(time.Now()) + pendingRequests := atomic.Int32{} + + handler, drainer := mainHandler(env, transport, prober, stats, logger, mp, tp, &pendingRequests) + + t.Run("tracks pending requests correctly", func(t *testing.T) { + // Make a regular request + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "test.example.com" + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + }() + + // Give the request time to start + time.Sleep(10 * time.Millisecond) + + // Check that pending request counter was incremented + count := pendingRequests.Load() + if count != 1 { + t.Errorf("Expected 1 pending request, got %d", count) + } + + wg.Wait() + + // Check that counter was decremented after completion + if pendingRequests.Load() != 0 { + t.Errorf("Expected 0 pending requests after completion, got %d", 
pendingRequests.Load()) + } + }) + + t.Run("does not track probe requests", func(t *testing.T) { + // Make a probe request + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(network.ProbeHeaderName, network.ProbeHeaderValue) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + // Check that pending request counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected 0 pending requests for probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("handles concurrent requests", func(t *testing.T) { + numRequests := 5 + var wg sync.WaitGroup + wg.Add(numRequests) + + for i := range numRequests { + go func(i int) { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "test.example.com" + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + }(i) + } + + // Give requests time to start + time.Sleep(20 * time.Millisecond) + + // Check that multiple requests are being tracked + count := pendingRequests.Load() + if count <= 0 || count > int32(numRequests) { + t.Errorf("Expected pending requests between 1 and %d, got %d", numRequests, count) + } + + wg.Wait() + + // Check that all requests completed + if pendingRequests.Load() != 0 { + t.Errorf("Expected 0 pending requests after all completed, got %d", pendingRequests.Load()) + } + }) + + t.Run("drainer integration", func(t *testing.T) { + // Start a long-running request + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.Host = "test.example.com" + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + }() + + // Give request time to start + time.Sleep(10 * time.Millisecond) + + // Verify request is being tracked + if pendingRequests.Load() != 1 { + t.Errorf("Expected 1 pending request before drain, got %d", pendingRequests.Load()) + } + + // Call drain + drainer.Drain() + + // Wait for request to complete + wg.Wait() + + // Verify counter is back to 0 + if pendingRequests.Load() != 0 { + t.Errorf("Expected 0 pending requests after drain, got %d", pendingRequests.Load()) + } + }) +} + +func TestBuildBreaker(t *testing.T) { + logger := zap.NewNop().Sugar() + + t.Run("returns nil for unlimited concurrency", func(t *testing.T) { + env := config{ + ContainerConcurrency: 0, + } + breaker := buildBreaker(logger, env) + if breaker != nil { + t.Error("Expected nil breaker for unlimited concurrency") + } + }) + + t.Run("creates breaker with correct params", func(t *testing.T) { + env := config{ + ContainerConcurrency: 10, + } + breaker := buildBreaker(logger, env) + if breaker == nil { + t.Fatal("Expected non-nil breaker") + } + // The breaker should be configured with QueueDepth = 10 * ContainerConcurrency + // and MaxConcurrency = ContainerConcurrency + }) +} + +func TestBuildProbe(t *testing.T) { + logger := zap.NewNop().Sugar() + + t.Run("creates probe without HTTP2 auto-detection", func(t *testing.T) { + encodedProbe := `{"httpGet":{"path":"/health","port":8080}}` + probe := buildProbe(logger, encodedProbe, false, false) + if probe == nil { + t.Fatal("Expected non-nil probe") + } + }) + + t.Run("creates probe with HTTP2 auto-detection", func(t *testing.T) { + encodedProbe := `{"httpGet":{"path":"/health","port":8080}}` + probe := buildProbe(logger, encodedProbe, true, false) + if probe == nil { + t.Fatal("Expected non-nil probe") + } + }) +} + +func TestExists(t *testing.T) { + logger := zap.NewNop().Sugar() + + t.Run("returns true for existing file", func(t *testing.T) 
{ + // Create a temporary file + tmpfile, err := os.CreateTemp("", "test") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tmpfile.Name()) + + if !exists(logger, tmpfile.Name()) { + t.Error("Expected true for existing file") + } + }) + + t.Run("returns false for non-existent file", func(t *testing.T) { + if exists(logger, "/non/existent/file/path") { + t.Error("Expected false for non-existent file") + } + }) +} diff --git a/pkg/queue/sharedmain/handlers_test.go b/pkg/queue/sharedmain/handlers_test.go new file mode 100644 index 000000000000..b035c6eee373 --- /dev/null +++ b/pkg/queue/sharedmain/handlers_test.go @@ -0,0 +1,359 @@ +/* +Copyright 2024 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sharedmain + +import ( + "context" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "go.uber.org/zap" + netheader "knative.dev/networking/pkg/http/header" + pkghandler "knative.dev/pkg/network/handlers" + "knative.dev/serving/pkg/queue" +) + +func TestDrainCompleteEndpoint(t *testing.T) { + logger := zap.NewNop().Sugar() + drainer := &pkghandler.Drainer{} + + t.Run("returns 200 when no pending requests", func(t *testing.T) { + pendingRequests := atomic.Int32{} + pendingRequests.Store(0) + + handler := adminHandler(context.Background(), logger, drainer, &pendingRequests) + + req := httptest.NewRequest(http.MethodGet, "/drain-complete", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + if w.Body.String() != "drained" { + t.Errorf("Expected body 'drained', got %s", w.Body.String()) + } + }) + + t.Run("returns 503 when requests are pending", func(t *testing.T) { + pendingRequests := atomic.Int32{} + pendingRequests.Store(5) + + handler := adminHandler(context.Background(), logger, drainer, &pendingRequests) + + req := httptest.NewRequest(http.MethodGet, "/drain-complete", nil) + w := httptest.NewRecorder() + + handler.ServeHTTP(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("Expected status 503, got %d", w.Code) + } + if w.Body.String() != "pending requests: 5" { + t.Errorf("Expected body 'pending requests: 5', got %s", w.Body.String()) + } + }) +} + +func TestRequestQueueDrainHandler(t *testing.T) { + logger := zap.NewNop().Sugar() + + t.Run("handles drain request when context is done", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + drainer := &pkghandler.Drainer{ + QuietPeriod: 100 * time.Millisecond, + Inner: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + pendingRequests := atomic.Int32{} + + handler := adminHandler(ctx, logger, drainer, &pendingRequests) + + // Cancel context to simulate TERM signal + cancel() + + req := httptest.NewRequest(http.MethodPost, queue.RequestQueueDrainPath, nil) + w := httptest.NewRecorder() + + // This should call drainer.Drain() and return immediately + handler.ServeHTTP(w, req) + + if w.Code != 
http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + + // Verify the drainer is in draining state by sending a probe + probeReq := httptest.NewRequest(http.MethodGet, "/", nil) + probeReq.Header.Set("User-Agent", "kube-probe/1.0") + probeW := httptest.NewRecorder() + drainer.ServeHTTP(probeW, probeReq) + + // Should return 503 because drainer is draining + if probeW.Code != http.StatusServiceUnavailable { + t.Errorf("Expected probe to return 503 during drain, got %d", probeW.Code) + } + }) + + t.Run("resets drainer after timeout when context not done", func(t *testing.T) { + ctx := context.Background() + drainer := &pkghandler.Drainer{ + QuietPeriod: 100 * time.Millisecond, + Inner: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }), + } + pendingRequests := atomic.Int32{} + + handler := adminHandler(ctx, logger, drainer, &pendingRequests) + + req := httptest.NewRequest(http.MethodPost, queue.RequestQueueDrainPath, nil) + w := httptest.NewRecorder() + + // Start the drain in a goroutine since Drain() blocks + done := make(chan bool) + go func() { + handler.ServeHTTP(w, req) + done <- true + }() + + // Give it time to start draining + time.Sleep(50 * time.Millisecond) + + // Check that drainer is draining + probeReq := httptest.NewRequest(http.MethodGet, "/", nil) + probeReq.Header.Set("User-Agent", "kube-probe/1.0") + probeW := httptest.NewRecorder() + drainer.ServeHTTP(probeW, probeReq) + + if probeW.Code != http.StatusServiceUnavailable { + t.Errorf("Expected probe to return 503 during drain, got %d", probeW.Code) + } + + // Wait for the reset to happen (after 1 second) + time.Sleep(1100 * time.Millisecond) + + // Check that drainer has been reset and is no longer draining + probeReq2 := httptest.NewRequest(http.MethodGet, "/", nil) + probeReq2.Header.Set("User-Agent", "kube-probe/1.0") + probeW2 := httptest.NewRecorder() + drainer.ServeHTTP(probeW2, probeReq2) + + // Should return 200 because drainer was reset + if probeW2.Code != http.StatusOK { + t.Errorf("Expected probe to return 200 after reset, got %d", probeW2.Code) + } + + // Clean up + select { + case <-done: + case <-time.After(2 * time.Second): + t.Error("Handler did not complete in time") + } + }) +} + +func TestWithRequestCounter(t *testing.T) { + pendingRequests := atomic.Int32{} + + // Create a test handler that we'll wrap + baseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Sleep briefly to ensure counter is incremented + time.Sleep(10 * time.Millisecond) + w.WriteHeader(http.StatusOK) + }) + + wrappedHandler := withRequestCounter(baseHandler, &pendingRequests) + + t.Run("increments counter for regular requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + wrappedHandler.ServeHTTP(w, req) + }() + + // Give the request time to start + time.Sleep(5 * time.Millisecond) + + // Check that counter was incremented + if pendingRequests.Load() != 1 { + t.Errorf("Expected pending requests to be 1, got %d", pendingRequests.Load()) + } + + wg.Wait() + + // Check that counter was decremented after request completed + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to be 0 after completion, got %d", pendingRequests.Load()) + } + }) + + t.Run("skips counter for probe requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + 
req.Header.Set(netheader.ProbeKey, netheader.ProbeValue) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check that counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to remain 0 for probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("skips counter for kube-probe requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("User-Agent", "kube-probe/1.27") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check that counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to remain 0 for kube-probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("skips counter for Activator probe requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("User-Agent", "Knative-Activator-Probe") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check that counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to remain 0 for Activator probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("skips counter for Autoscaling probe requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("User-Agent", "Knative-Autoscaling-Probe") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check that counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to remain 0 for Autoscaling probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("skips counter for QueueProxy probe requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("User-Agent", "Knative-Queue-Proxy-Probe") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check that counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to remain 0 for QueueProxy probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("skips counter for Ingress probe requests", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("User-Agent", "Knative-Ingress-Probe") + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + // Check that counter was not incremented + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to remain 0 for Ingress probe, got %d", pendingRequests.Load()) + } + }) + + t.Run("handles concurrent requests correctly", func(t *testing.T) { + // Reset counter + pendingRequests.Store(0) + + numRequests := 10 + var wg sync.WaitGroup + wg.Add(numRequests) + + for range numRequests { + go func() { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + wrappedHandler.ServeHTTP(w, req) + }() + } + + // Give requests time to start + time.Sleep(5 * time.Millisecond) + + // Check that all requests are being tracked + count := pendingRequests.Load() + if count <= 0 || count > int32(numRequests) { + t.Errorf("Expected pending requests to be between 1 and %d, got %d", numRequests, count) + } + + wg.Wait() + + // Check that counter returned to 0 + if pendingRequests.Load() != 0 { + t.Errorf("Expected pending requests to be 0 after all completed, got %d", pendingRequests.Load()) + } + }) +} + +func TestWithFullDuplex(t *testing.T) { + logger := zap.NewNop().Sugar() + + baseHandler := http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + t.Run("passes through when disabled", func(t *testing.T) { + wrappedHandler := withFullDuplex(baseHandler, false, logger) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) + + t.Run("enables full duplex when configured", func(t *testing.T) { + wrappedHandler := withFullDuplex(baseHandler, true, logger) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + w := httptest.NewRecorder() + + wrappedHandler.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", w.Code) + } + }) +} diff --git a/pkg/queue/sharedmain/main.go b/pkg/queue/sharedmain/main.go index 5c1b31a65ccf..37d28b2b0f4b 100644 --- a/pkg/queue/sharedmain/main.go +++ b/pkg/queue/sharedmain/main.go @@ -25,6 +25,7 @@ import ( "net/http" "os" "strconv" + "sync/atomic" "time" "github.com/kelseyhightower/envconfig" @@ -48,6 +49,7 @@ import ( "knative.dev/serving/pkg/observability" "knative.dev/serving/pkg/queue" "knative.dev/serving/pkg/queue/certificate" + "knative.dev/serving/pkg/queue/drain" "knative.dev/serving/pkg/queue/readiness" ) @@ -58,7 +60,7 @@ const ( // Duration the /wait-for-drain handler should wait before returning. // This is to give networking a little bit more time to remove the pod // from its configuration and propagate that to all loadbalancers and nodes. - drainSleepDuration = 30 * time.Second + drainSleepDuration = 15 * time.Second // certPath is the path for the server certificate mounted by queue-proxy. certPath = queue.CertDirectory + "/" + certificates.CertName @@ -157,6 +159,8 @@ func Main(opts ...Option) error { d := Defaults{ Ctx: signals.NewContext(), } + pendingRequests := atomic.Int32{} + pendingRequests.Store(0) // Parse the environment. env := config{ @@ -231,9 +235,8 @@ func Main(opts ...Option) error { // Enable TLS when certificate is mounted. tlsEnabled := exists(logger, certPath) && exists(logger, keyPath) - - mainHandler, drainer := mainHandler(env, d.Transport, probe, stats, logger, mp, tp) - adminHandler := adminHandler(d.Ctx, logger, drainer) + mainHandler, drainer := mainHandler(env, d.Transport, probe, stats, logger, mp, tp, &pendingRequests) + adminHandler := adminHandler(d.Ctx, logger, drainer, &pendingRequests) // Enable TLS server when activator server certs are mounted. // At this moment activator with TLS does not disable HTTP. @@ -271,6 +274,10 @@ func Main(opts ...Option) error { logger.Info("Starting queue-proxy") + // Clean up any stale drain signal files from previous runs + os.Remove(drain.DrainStartedFile) + os.Remove(drain.DrainCompleteFile) + errCh := make(chan error) for name, server := range httpServers { go func(name string, s *http.Server) { @@ -304,9 +311,29 @@ func Main(opts ...Option) error { return err case <-d.Ctx.Done(): logger.Info("Received TERM signal, attempting to gracefully shutdown servers.") - logger.Infof("Sleeping %v to allow K8s propagation of non-ready state", drainSleepDuration) drainer.Drain() + // Wait on active requests to complete. This is done explicitly + // to avoid closing any connections which have been highjacked, + // as in net/http `.Shutdown` would do so ungracefully. 
+ // See https://github.com/golang/go/issues/17721 + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + logger.Infof("Drain: waiting for %d pending requests to complete", pendingRequests.Load()) + WaitOnPendingRequests: + for range ticker.C { + if pendingRequests.Load() <= 0 { + logger.Infof("Drain: all pending requests completed") + break WaitOnPendingRequests + } + } + + // Write drain-complete signal file after draining is done + // This signals to user containers that queue-proxy has finished draining + if err := os.WriteFile(drain.DrainCompleteFile, []byte(""), 0o600); err != nil { + logger.Errorw("Failed to write drain-complete signal file", zap.Error(err)) + } + for name, srv := range httpServers { logger.Info("Shutting down server: ", name) if err := srv.Shutdown(context.Background()); err != nil { diff --git a/pkg/queue/sharedmain/shutdown_integration_test.go b/pkg/queue/sharedmain/shutdown_integration_test.go new file mode 100644 index 000000000000..5c4bfe0e7444 --- /dev/null +++ b/pkg/queue/sharedmain/shutdown_integration_test.go @@ -0,0 +1,403 @@ +//go:build integration +// +build integration + +/* +Copyright 2024 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sharedmain + +import ( + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "go.uber.org/zap/zaptest" +) + +// TestShutdownCoordination_NormalSequence tests the normal shutdown sequence +// where queue-proxy writes drain signals and user container waits appropriately +func TestShutdownCoordination_NormalSequence(t *testing.T) { + tmpDir := t.TempDir() + drainDir := filepath.Join(tmpDir, "knative") + if err := os.MkdirAll(drainDir, 0755); err != nil { + t.Fatal(err) + } + + drainStarted := filepath.Join(drainDir, "drain-started") + drainComplete := filepath.Join(drainDir, "drain-complete") + + // Simulate queue-proxy PreStop + queueProxyPreStop := func() { + if err := os.WriteFile(drainStarted, []byte(""), 0600); err != nil { + t.Errorf("Failed to write drain-started: %v", err) + } + } + + // Simulate user container PreStop with exponential backoff + userPreStopStarted := make(chan struct{}) + userPreStopCompleted := make(chan struct{}) + + go func() { + close(userPreStopStarted) + // Simulate the actual PreStop script logic + for _, delay := range []int{1, 2, 4, 8} { + if _, err := os.Stat(drainStarted); err == nil { + // drain-started exists, wait for drain-complete + for i := 0; i < 30; i++ { // Max 30 seconds wait + if _, err := os.Stat(drainComplete); err == nil { + close(userPreStopCompleted) + return + } + time.Sleep(1 * time.Second) + } + } + time.Sleep(time.Duration(delay) * time.Second) + } + // Exit after retries + close(userPreStopCompleted) + }() + + // Wait for user PreStop to start + <-userPreStopStarted + + // Simulate some delay before queue-proxy writes the file + time.Sleep(500 * time.Millisecond) + + // Execute queue-proxy PreStop + queueProxyPreStop() + + // Simulate queue-proxy draining and writing complete signal + time.Sleep(2 * time.Second) + if err 
:= os.WriteFile(drainComplete, []byte(""), 0600); err != nil { + t.Fatal(err) + } + + // Wait for user PreStop to complete + select { + case <-userPreStopCompleted: + // Success + case <-time.After(20 * time.Second): + t.Fatal("User PreStop did not complete in time") + } + + // Verify both files exist + if _, err := os.Stat(drainStarted); os.IsNotExist(err) { + t.Error("drain-started file was not created") + } + if _, err := os.Stat(drainComplete); os.IsNotExist(err) { + t.Error("drain-complete file was not created") + } +} + +// TestShutdownCoordination_QueueProxyCrash tests behavior when queue-proxy +// crashes or fails to write the drain-started signal +func TestShutdownCoordination_QueueProxyCrash(t *testing.T) { + tmpDir := t.TempDir() + drainDir := filepath.Join(tmpDir, "knative") + if err := os.MkdirAll(drainDir, 0755); err != nil { + t.Fatal(err) + } + + drainStarted := filepath.Join(drainDir, "drain-started") + + // Simulate user container PreStop without queue-proxy creating the file + userPreStopCompleted := make(chan struct{}) + startTime := time.Now() + + go func() { + // Simulate the actual PreStop script logic + for _, delay := range []int{1, 2, 4, 8} { + if _, err := os.Stat(drainStarted); err == nil { + t.Error("drain-started should not exist in crash scenario") + } + time.Sleep(time.Duration(delay) * time.Second) + } + // Should exit after retries + close(userPreStopCompleted) + }() + + // Wait for user PreStop to complete + select { + case <-userPreStopCompleted: + elapsed := time.Since(startTime) + // Should complete after 1+2+4+8 = 15 seconds + if elapsed < 14*time.Second || elapsed > 16*time.Second { + t.Errorf("PreStop took %v, expected ~15s", elapsed) + } + case <-time.After(20 * time.Second): + t.Fatal("User PreStop did not complete after retries") + } +} + +// TestShutdownCoordination_HighLoad tests shutdown under high request load +func TestShutdownCoordination_HighLoad(t *testing.T) { + tmpDir := t.TempDir() + drainDir := filepath.Join(tmpDir, "knative") + if err := os.MkdirAll(drainDir, 0755); err != nil { + t.Fatal(err) + } + + drainStarted := filepath.Join(drainDir, "drain-started") + drainComplete := filepath.Join(drainDir, "drain-complete") + + // Simulate active requests + var pendingRequests int32 = 100 + var wg sync.WaitGroup + + // Simulate queue-proxy handling requests + for i := 0; i < 100; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + // Simulate request processing + time.Sleep(time.Duration(100+id*10) * time.Millisecond) + atomic.AddInt32(&pendingRequests, -1) + }(i) + } + + // Start shutdown sequence + go func() { + // Write drain-started immediately + if err := os.WriteFile(drainStarted, []byte(""), 0600); err != nil { + t.Errorf("Failed to write drain-started: %v", err) + } + + // Wait for all requests to complete + for atomic.LoadInt32(&pendingRequests) > 0 { + time.Sleep(100 * time.Millisecond) + } + + // Write drain-complete after all requests are done + if err := os.WriteFile(drainComplete, []byte(""), 0600); err != nil { + t.Errorf("Failed to write drain-complete: %v", err) + } + }() + + // Simulate user container waiting + userPreStopCompleted := make(chan struct{}) + go func() { + for _, delay := range []int{1, 2, 4, 8} { + if _, err := os.Stat(drainStarted); err == nil { + // Wait for drain-complete + for i := 0; i < 30; i++ { + if _, err := os.Stat(drainComplete); err == nil { + close(userPreStopCompleted) + return + } + time.Sleep(1 * time.Second) + } + } + time.Sleep(time.Duration(delay) * time.Second) + } + t.Error("User 
PreStop exited without seeing drain-complete") + }() + + // Wait for all requests to complete + wg.Wait() + + // Ensure user PreStop completes + select { + case <-userPreStopCompleted: + // Verify no requests remain + if atomic.LoadInt32(&pendingRequests) != 0 { + t.Errorf("Requests remaining: %d", pendingRequests) + } + case <-time.After(30 * time.Second): + t.Fatal("User PreStop did not complete under load") + } +} + +// TestShutdownCoordination_FilePermissions tests behavior with file system issues +func TestShutdownCoordination_FilePermissions(t *testing.T) { + // Skip if not running as root (can't test permission issues properly) + if os.Geteuid() == 0 { + t.Skip("Cannot test permission issues as root") + } + + tmpDir := t.TempDir() + drainDir := filepath.Join(tmpDir, "knative") + if err := os.MkdirAll(drainDir, 0755); err != nil { + t.Fatal(err) + } + + // Make directory read-only + if err := os.Chmod(drainDir, 0555); err != nil { + t.Fatal(err) + } + defer os.Chmod(drainDir, 0755) // Restore for cleanup + + drainStarted := filepath.Join(drainDir, "drain-started") + + // Try to write drain-started (should fail) + err := os.WriteFile(drainStarted, []byte(""), 0600) + if err == nil { + t.Error("Expected write to fail with read-only directory") + } + + // User PreStop should still complete after retries + userPreStopCompleted := make(chan struct{}) + go func() { + for _, delay := range []int{1, 2, 4, 8} { + if _, err := os.Stat(drainStarted); err == nil { + t.Error("File should not exist with permission issues") + } + time.Sleep(time.Duration(delay) * time.Millisecond) // Use ms for faster test + } + close(userPreStopCompleted) + }() + + select { + case <-userPreStopCompleted: + // Success - PreStop completed despite permission issues + case <-time.After(5 * time.Second): + t.Fatal("User PreStop did not complete with permission issues") + } +} + +// TestShutdownCoordination_RaceCondition tests for race conditions +// between queue-proxy and user container PreStop hooks +func TestShutdownCoordination_RaceCondition(t *testing.T) { + tmpDir := t.TempDir() + drainDir := filepath.Join(tmpDir, "knative") + + // Run multiple iterations to catch race conditions + for i := 0; i < 50; i++ { + // Clean up from previous iteration + os.RemoveAll(drainDir) + if err := os.MkdirAll(drainDir, 0755); err != nil { + t.Fatal(err) + } + + drainStarted := filepath.Join(drainDir, "drain-started") + drainComplete := filepath.Join(drainDir, "drain-complete") + + var wg sync.WaitGroup + wg.Add(2) + + // Queue-proxy PreStop and shutdown + go func() { + defer wg.Done() + // Random delay to create race conditions + time.Sleep(time.Duration(i%10) * time.Millisecond) + os.WriteFile(drainStarted, []byte(""), 0600) + time.Sleep(time.Duration(i%5) * time.Millisecond) + os.WriteFile(drainComplete, []byte(""), 0600) + }() + + // User container PreStop + completed := make(chan bool, 1) + go func() { + defer wg.Done() + timeout := time.After(20 * time.Second) + for _, delay := range []int{1, 2, 4, 8} { + select { + case <-timeout: + completed <- false + return + default: + } + + if _, err := os.Stat(drainStarted); err == nil { + // Wait for complete + for j := 0; j < 30; j++ { + if _, err := os.Stat(drainComplete); err == nil { + completed <- true + return + } + time.Sleep(100 * time.Millisecond) + } + } + time.Sleep(time.Duration(delay) * time.Millisecond) + } + completed <- true // Exit after retries + }() + + // Wait for both to complete + wg.Wait() + + if !<-completed { + t.Errorf("Iteration %d: User PreStop timed out", 
i) + } + } +} + +// TestShutdownCoordination_LongRunningRequests tests behavior with +// requests that take longer than the grace period +func TestShutdownCoordination_LongRunningRequests(t *testing.T) { + logger := zaptest.NewLogger(t).Sugar() + + tmpDir := t.TempDir() + drainDir := filepath.Join(tmpDir, "knative") + if err := os.MkdirAll(drainDir, 0755); err != nil { + t.Fatal(err) + } + + drainStarted := filepath.Join(drainDir, "drain-started") + drainComplete := filepath.Join(drainDir, "drain-complete") + + // Simulate a long-running request + requestComplete := make(chan struct{}) + go func() { + logger.Info("Starting long-running request") + time.Sleep(10 * time.Second) // Longer than typical drain timeout + close(requestComplete) + logger.Info("Long-running request completed") + }() + + // Start shutdown + go func() { + os.WriteFile(drainStarted, []byte(""), 0600) + + // In real scenario, this would wait for requests or timeout + select { + case <-requestComplete: + logger.Info("Request completed, writing drain-complete") + case <-time.After(5 * time.Second): + logger.Info("Timeout waiting for request, writing drain-complete anyway") + } + + os.WriteFile(drainComplete, []byte(""), 0600) + }() + + // User container should still proceed + userExited := make(chan struct{}) + go func() { + for _, delay := range []int{1, 2, 4, 8} { + if _, err := os.Stat(drainStarted); err == nil { + // Wait for complete with timeout + for i := 0; i < 10; i++ { + if _, err := os.Stat(drainComplete); err == nil { + close(userExited) + return + } + time.Sleep(1 * time.Second) + } + } + time.Sleep(time.Duration(delay) * time.Second) + } + close(userExited) + }() + + select { + case <-userExited: + // User container should exit even with long-running request + case <-time.After(30 * time.Second): + t.Fatal("User container did not exit with long-running request") + } +} diff --git a/pkg/reconciler/revision/resources/deploy.go b/pkg/reconciler/revision/resources/deploy.go index fff3847cb7eb..9f103527bba6 100644 --- a/pkg/reconciler/revision/resources/deploy.go +++ b/pkg/reconciler/revision/resources/deploy.go @@ -30,6 +30,7 @@ import ( v1 "knative.dev/serving/pkg/apis/serving/v1" "knative.dev/serving/pkg/networking" "knative.dev/serving/pkg/queue" + "knative.dev/serving/pkg/queue/drain" "knative.dev/serving/pkg/reconciler/revision/config" "knative.dev/serving/pkg/reconciler/revision/resources/names" @@ -98,18 +99,17 @@ var ( ReadOnly: true, } - // This PreStop hook is actually calling an endpoint on the queue-proxy - // because of the way PreStop hooks are called by kubelet. We use this - // to block the user-container from exiting before the queue-proxy is ready - // to exit so we can guarantee that there are no more requests in flight. 
- userLifecycle = &corev1.Lifecycle{ - PreStop: &corev1.LifecycleHandler{ - HTTPGet: &corev1.HTTPGetAction{ - Port: intstr.FromInt(networking.QueueAdminPort), - Path: queue.RequestQueueDrainPath, - }, + varDrainVolume = corev1.Volume{ + Name: "knative-drain-signal", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, }, } + + varDrainVolumeMount = corev1.VolumeMount{ + Name: "knative-drain-signal", + MountPath: "/var/run/knative", + } ) func addToken(tokenVolume *corev1.Volume, filename string, audience string, expiry *int64) { @@ -173,7 +173,9 @@ func makePodSpec(rev *v1.Revision, cfg *config.Config) (*corev1.PodSpec, error) return nil, fmt.Errorf("failed to create queue-proxy container: %w", err) } + // Add drain volume for signaling between containers var extraVolumes []corev1.Volume + extraVolumes = append(extraVolumes, varDrainVolume) podInfoFeature, podInfoExists := rev.Annotations[apiconfig.QueueProxyPodInfoFeatureKey] @@ -261,7 +263,7 @@ func BuildUserContainers(rev *v1.Revision) []corev1.Container { func makeContainer(container corev1.Container, rev *v1.Revision) corev1.Container { // Adding or removing an overwritten corev1.Container field here? Don't forget to // update the fieldmasks / validations in pkg/apis/serving - container.Lifecycle = userLifecycle + container.Lifecycle = buildLifecycleWithDrainWait(container.Lifecycle) container.Env = append(container.Env, getKnativeEnvVar(rev)...) // Explicitly disable stdin and tty allocation @@ -279,6 +281,9 @@ func makeContainer(container corev1.Container, rev *v1.Revision) corev1.Containe } } + // Mount the drain volume for PreStop hook to check drain signal + container.VolumeMounts = append(container.VolumeMounts, varDrainVolumeMount) + return container } @@ -294,6 +299,50 @@ func makeServingContainer(servingContainer corev1.Container, rev *v1.Revision) c return container } +// buildLifecycleWithDrainWait preserves any existing pre-stop hooks and adds the drain wait +func buildLifecycleWithDrainWait(existingLifecycle *corev1.Lifecycle) *corev1.Lifecycle { + // If there's an existing lifecycle with a pre-stop hook, preserve it + if existingLifecycle != nil && existingLifecycle.PreStop != nil { + // Convert existing pre-stop to exec command if needed + var existingCommand string + if existingLifecycle.PreStop.Exec != nil { + existingCommand = strings.Join(existingLifecycle.PreStop.Exec.Command, " ") + } else if existingLifecycle.PreStop.HTTPGet != nil { + // Convert HTTP GET to curl command + port := existingLifecycle.PreStop.HTTPGet.Port.String() + path := existingLifecycle.PreStop.HTTPGet.Path + if path == "" { + path = "/" + } + existingCommand = fmt.Sprintf("curl -f http://localhost:%s%s", port, path) + } + + // Combine: run existing hook first, then wait for drain + return &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{ + "/bin/sh", "-c", + drain.BuildDrainWaitScript(existingCommand), + }, + }, + }, + } + } + + // No existing lifecycle, just add the drain wait + return &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{ + "/bin/sh", "-c", + drain.BuildDrainWaitScript(""), + }, + }, + }, + } +} + // BuildPodSpec creates a PodSpec from the given revision and containers. // cfg can be passed as nil if not within revision reconciliation context. 
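The pkg/queue/drain helpers used above (QueueProxyPreStopScript and BuildDrainWaitScript) are not included in this part of the patch. A minimal sketch of what they plausibly provide follows, assuming the package only wraps small shell snippets: the const value and the /var/run/knative paths are taken from the deploy_test.go expectations and varDrainVolumeMount, while the script body and its 1/2/4/8 retry backoff are inferred from the shutdown-coordination tests and are not the actual implementation.

// A minimal sketch, not the real pkg/queue/drain implementation.
package drain

import "fmt"

// QueueProxyPreStopScript is what deploy_test.go expects the queue-proxy
// PreStop hook to run: it only signals that draining has begun.
const QueueProxyPreStopScript = "touch /var/run/knative/drain-started"

// BuildDrainWaitScript returns the user-container PreStop script: run any
// pre-existing hook first, then block until the queue-proxy signals that
// draining has finished. The polling/backoff shape is an assumption.
func BuildDrainWaitScript(existingCommand string) string {
	wait := `for delay in 1 2 4 8; do
  if [ -f /var/run/knative/drain-started ]; then
    i=0
    while [ $i -lt 30 ]; do
      [ -f /var/run/knative/drain-complete ] && exit 0
      sleep 1
      i=$((i+1))
    done
    exit 0
  fi
  sleep $delay
done`
	if existingCommand == "" {
		return wait
	}
	return fmt.Sprintf("%s\n%s", existingCommand, wait)
}

Because both containers mount the same knative-drain-signal emptyDir at /var/run/knative, these files act as the cross-container signal that the old HTTP-based PreStop hook used to provide.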
func BuildPodSpec(rev *v1.Revision, containers []corev1.Container, cfg *config.Config) *corev1.PodSpec { diff --git a/pkg/reconciler/revision/resources/deploy_lifecycle_test.go b/pkg/reconciler/revision/resources/deploy_lifecycle_test.go new file mode 100644 index 000000000000..3551fcef6338 --- /dev/null +++ b/pkg/reconciler/revision/resources/deploy_lifecycle_test.go @@ -0,0 +1,97 @@ +/* +Copyright 2024 The Knative Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package resources + +import ( + "testing" + + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "knative.dev/serving/pkg/queue/drain" +) + +func TestBuildLifecycleWithDrainWait(t *testing.T) { + // Use the same logic as production code + drainCommand := drain.BuildDrainWaitScript("") + + tests := []struct { + name string + existing *corev1.Lifecycle + want []string + }{ + { + name: "no existing lifecycle", + existing: nil, + want: []string{"/bin/sh", "-c", drainCommand}, + }, + { + name: "existing exec command", + existing: &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{"/app/cleanup.sh"}, + }, + }, + }, + want: []string{"/bin/sh", "-c", drain.BuildDrainWaitScript("/app/cleanup.sh")}, + }, + { + name: "existing HTTP GET", + existing: &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Port: intstr.FromInt(8080), + Path: "/shutdown", + }, + }, + }, + want: []string{"/bin/sh", "-c", drain.BuildDrainWaitScript("curl -f http://localhost:8080/shutdown")}, + }, + { + name: "existing HTTP GET without path", + existing: &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Port: intstr.FromInt(9090), + }, + }, + }, + want: []string{"/bin/sh", "-c", drain.BuildDrainWaitScript("curl -f http://localhost:9090/")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := buildLifecycleWithDrainWait(tt.existing) + + if result == nil || result.PreStop == nil || result.PreStop.Exec == nil { + t.Fatal("Expected lifecycle with exec prestop handler") + } + + gotCommand := result.PreStop.Exec.Command + if len(gotCommand) != len(tt.want) { + t.Errorf("Command length mismatch: got %d, want %d", len(gotCommand), len(tt.want)) + } + + for i, cmd := range tt.want { + if i < len(gotCommand) && gotCommand[i] != cmd { + t.Errorf("Command[%d]: got %q, want %q", i, gotCommand[i], cmd) + } + } + }) + } +} diff --git a/pkg/reconciler/revision/resources/deploy_test.go b/pkg/reconciler/revision/resources/deploy_test.go index 897428fff91e..012adc9406fa 100644 --- a/pkg/reconciler/revision/resources/deploy_test.go +++ b/pkg/reconciler/revision/resources/deploy_test.go @@ -56,10 +56,14 @@ var ( Name: servingContainerName, Image: "busybox", Ports: buildContainerPorts(v1.DefaultUserPort), - Lifecycle: userLifecycle, + Lifecycle: buildLifecycleWithDrainWait(nil), TerminationMessagePolicy: corev1.TerminationMessageFallbackToLogsOnError, Stdin: false, TTY: false, + VolumeMounts: 
[]corev1.VolumeMount{{ + Name: "knative-drain-signal", + MountPath: "/var/run/knative", + }}, Env: []corev1.EnvVar{{ Name: "PORT", Value: "8080", @@ -87,9 +91,20 @@ var ( }}, }, }, - PeriodSeconds: 0, + PeriodSeconds: 1, + FailureThreshold: 1, }, SecurityContext: queueSecurityContext, + Lifecycle: &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{ + "/bin/sh", "-c", + "touch " + "/var/run/knative/drain-started", // Using string directly to match production + }, + }, + }, + }, Env: []corev1.EnvVar{{ Name: "SERVING_NAMESPACE", Value: "foo", // matches namespace @@ -143,7 +158,7 @@ var ( Value: system.Namespace(), }, { Name: "SERVING_READINESS_PROBE", - Value: fmt.Sprintf(`{"tcpSocket":{"port":%d,"host":"127.0.0.1"}}`, v1.DefaultUserPort), + Value: fmt.Sprintf(`{"tcpSocket":{"port":%d,"host":"127.0.0.1"},"failureThreshold":1}`, v1.DefaultUserPort), }, { Name: "HOST_IP", ValueFrom: &corev1.EnvVarSource{ @@ -165,11 +180,23 @@ var ( Name: "OBSERVABILITY_CONFIG", Value: `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{}}`, }}, + VolumeMounts: []corev1.VolumeMount{{ + Name: "knative-drain-signal", + MountPath: "/var/run/knative", + }}, } defaultPodSpec = &corev1.PodSpec{ TerminationGracePeriodSeconds: ptr.Int64(45), EnableServiceLinks: ptr.Bool(false), + Volumes: []corev1.Volume{ + { + Name: "knative-drain-signal", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + }, } defaultPodAntiAffinityRules = &corev1.PodAntiAffinity{ @@ -253,10 +280,14 @@ func defaultSidecarContainer(containerName string) *corev1.Container { return &corev1.Container{ Name: containerName, Image: "ubuntu", - Lifecycle: userLifecycle, + Lifecycle: buildLifecycleWithDrainWait(nil), TerminationMessagePolicy: corev1.TerminationMessageFallbackToLogsOnError, Stdin: false, TTY: false, + VolumeMounts: []corev1.VolumeMount{{ + Name: "knative-drain-signal", + MountPath: "/var/run/knative", + }}, Env: []corev1.EnvVar{{ Name: "K_REVISION", Value: "bar", @@ -304,6 +335,11 @@ func queueContainer(opts ...containerOption) corev1.Container { return container(defaultQueueContainer.DeepCopy(), opts...) 
} +// Helper to get default probe JSON with failureThreshold +func defaultProbeJSON(port int32) string { + return fmt.Sprintf(`{"tcpSocket":{"port":%d,"host":"127.0.0.1"},"failureThreshold":1}`, port) +} + func withEnvVar(name, value string) containerOption { return func(container *corev1.Container) { for i, envVar := range container.Env { @@ -420,7 +456,7 @@ func withAppendedTokenVolumes(appended []appendTokenVolume) podSpecOption { Audience: a.audience, }, } - tokenVolume.VolumeSource.Projected.Sources = append(tokenVolume.VolumeSource.Projected.Sources, *token) + tokenVolume.Projected.Sources = append(tokenVolume.Projected.Sources, *token) } ps.Volumes = append(ps.Volumes, *tokenVolume) } @@ -448,8 +484,8 @@ func appsv1deployment(opts ...deploymentOption) *appsv1.Deployment { func revision(name, ns string, opts ...RevisionOption) *v1.Revision { revision := defaultRevision() - revision.ObjectMeta.Name = name - revision.ObjectMeta.Namespace = ns + revision.Name = name + revision.Namespace = ns for _, option := range opts { option(revision) } @@ -475,7 +511,7 @@ func withoutLabels(revision *v1.Revision) { func withOwnerReference(name string) RevisionOption { return func(revision *v1.Revision) { - revision.ObjectMeta.OwnerReferences = []metav1.OwnerReference{{ + revision.OwnerReferences = []metav1.OwnerReference{{ APIVersion: v1.SchemeGroupVersion.String(), Kind: "Configuration", Name: name, @@ -582,7 +618,7 @@ func TestMakePodSpec(t *testing.T) { ), queueContainer( withEnvVar("USER_PORT", "8888"), - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8888,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8888,"host":"127.0.0.1"},"failureThreshold":1}`), ), }), }, { @@ -629,7 +665,7 @@ func TestMakePodSpec(t *testing.T) { ), queueContainer( withEnvVar("USER_PORT", "8888"), - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8888,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8888,"host":"127.0.0.1"},"failureThreshold":1}`), ), }, withPrependedVolumes(corev1.Volume{ Name: "asdf", @@ -656,7 +692,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, func(p *corev1.PodSpec) { p.EnableServiceLinks = ptr.Bool(true) }), @@ -683,7 +721,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, func(p *corev1.PodSpec) { p.EnableServiceLinks = nil }), @@ -789,7 +829,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, ), }, { @@ -840,7 +882,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, ), }, { @@ -863,7 +907,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - 
queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, ), }, { @@ -930,7 +976,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox@sha256:deadbeef" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1}`), ), }), }, { @@ -951,7 +997,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox@sha256:deadbeef" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"grpc":{"port":8080,"service":null}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"grpc":{"port":8080,"service":null},"failureThreshold":1}`), ), }), }, { @@ -972,7 +1018,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox@sha256:deadbeef" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":12345,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":12345,"host":"127.0.0.1"},"failureThreshold":1}`), ), }), }, { @@ -995,7 +1041,7 @@ func TestMakePodSpec(t *testing.T) { container.ReadinessProbe = withExecReadinessProbe([]string{"echo", "hello"}) }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8080,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8080,"host":"127.0.0.1"},"failureThreshold":1}`), ), }), }, { @@ -1030,7 +1076,9 @@ func TestMakePodSpec(t *testing.T) { }, }), ), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }), }, { name: "with tcp liveness probe", @@ -1062,7 +1110,9 @@ func TestMakePodSpec(t *testing.T) { }, }), ), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }), }, { name: "with HTTP startup probe", @@ -1095,7 +1145,9 @@ func TestMakePodSpec(t *testing.T) { }, }), ), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }), }, { name: "with TCP startup probe", @@ -1125,7 +1177,9 @@ func TestMakePodSpec(t *testing.T) { TCPSocket: &corev1.TCPSocketAction{}, }), ), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }), }, { name: "complex pod spec", @@ -1158,6 +1212,7 @@ func TestMakePodSpec(t *testing.T) { ), queueContainer( withEnvVar("SERVING_SERVICE", "svc"), + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), ), }), }, { @@ -1276,7 +1331,7 @@ func TestMakePodSpec(t *testing.T) { queueContainer( withEnvVar("SERVING_SERVICE", "svc"), withEnvVar("USER_PORT", "8888"), - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8888,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8888,"host":"127.0.0.1"},"failureThreshold":1}`), ), }), }, { @@ -1302,7 +1357,7 @@ func TestMakePodSpec(t *testing.T) { ), queueContainer( withEnvVar("USER_PORT", "8080"), - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8080,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8080,"host":"127.0.0.1"},"failureThreshold":1}`), ), }, func(p *corev1.PodSpec) { @@ -1334,11 +1389,11 @@ func TestMakePodSpec(t *testing.T) { []corev1.Container{ servingContainer(func(container *corev1.Container) { 
container.Image = "busybox@sha256:deadbeef" - container.VolumeMounts = []corev1.VolumeMount{{ + container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{ Name: varLogVolume.Name, MountPath: "/var/log", SubPathExpr: "$(K_INTERNAL_POD_NAMESPACE)_$(K_INTERNAL_POD_NAME)_" + servingContainerName, - }} + }) container.Env = append(container.Env, corev1.EnvVar{ Name: "K_INTERNAL_POD_NAME", @@ -1351,11 +1406,11 @@ func TestMakePodSpec(t *testing.T) { }), sidecarContainer(sidecarContainerName, func(c *corev1.Container) { c.Image = "ubuntu@sha256:deadbeef" - c.VolumeMounts = []corev1.VolumeMount{{ + c.VolumeMounts = append(c.VolumeMounts, corev1.VolumeMount{ Name: varLogVolume.Name, MountPath: "/var/log", SubPathExpr: "$(K_INTERNAL_POD_NAMESPACE)_$(K_INTERNAL_POD_NAME)_" + sidecarContainerName, - }} + }) c.Env = append(c.Env, corev1.EnvVar{ Name: "K_INTERNAL_POD_NAME", @@ -1367,7 +1422,7 @@ func TestMakePodSpec(t *testing.T) { }) }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8080,"host":"127.0.0.1"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"tcpSocket":{"port":8080,"host":"127.0.0.1"},"failureThreshold":1}`), withEnvVar("OBSERVABILITY_CONFIG", `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{},"EnableVarLogCollection":true}`), ), }, @@ -1416,10 +1471,10 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox@sha256:deadbeef" }), queueContainer(func(container *corev1.Container) { - container.VolumeMounts = []corev1.VolumeMount{{ + container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{ Name: varTokenVolume.Name, MountPath: "/var/run/secrets/tokens", - }} + }) }), }, withAppendedTokenVolumes([]appendTokenVolume{{filename: "boo-srv", audience: "boo-srv", expires: 3600}}), @@ -1486,7 +1541,7 @@ func TestMakePodSpec(t *testing.T) { ), queueContainer( withEnvVar("ENABLE_MULTI_CONTAINER_PROBES", "true"), - withEnvVar("SERVING_READINESS_PROBE", `[{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"}},{"httpGet":{"path":"/","port":8090,"host":"127.0.0.1","scheme":"HTTP"}}]`), + withEnvVar("SERVING_READINESS_PROBE", `[{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1},{"httpGet":{"path":"/","port":8090,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1}]`), ), }), }, { @@ -1525,7 +1580,7 @@ func TestMakePodSpec(t *testing.T) { ), queueContainer( withEnvVar("ENABLE_MULTI_CONTAINER_PROBES", "true"), - withEnvVar("SERVING_READINESS_PROBE", `[{"tcpSocket":{"port":8080,"host":"127.0.0.1"}}]`), + withEnvVar("SERVING_READINESS_PROBE", `[{"tcpSocket":{"port":8080,"host":"127.0.0.1"},"failureThreshold":1}]`), ), }), }, { @@ -1551,7 +1606,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, func(p *corev1.PodSpec) { p.Affinity = &corev1.Affinity{ @@ -1582,7 +1639,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, ), }, { @@ -1612,7 +1671,9 @@ func TestMakePodSpec(t *testing.T) { servingContainer(func(container *corev1.Container) { container.Image = "busybox@sha256:deadbeef" }), - queueContainer(), + queueContainer( + 
withEnvVar("SERVING_READINESS_PROBE", defaultProbeJSON(v1.DefaultUserPort)), + ), }, func(p *corev1.PodSpec) { p.Affinity = &corev1.Affinity{ @@ -1640,7 +1701,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1}`), ), }, withRuntimeClass("gvisor")), }, { @@ -1667,7 +1728,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1}`), ), }), }, { @@ -1695,7 +1756,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1}`), ), }, withRuntimeClass("gvisor")), }, { @@ -1724,7 +1785,7 @@ func TestMakePodSpec(t *testing.T) { container.Image = "busybox" }), queueContainer( - withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"}}`), + withEnvVar("SERVING_READINESS_PROBE", `{"httpGet":{"path":"/","port":8080,"host":"127.0.0.1","scheme":"HTTP"},"failureThreshold":1}`), ), }, withRuntimeClass("kata")), }} diff --git a/pkg/reconciler/revision/resources/queue.go b/pkg/reconciler/revision/resources/queue.go index a5754f8e4a32..b3819cf388e6 100644 --- a/pkg/reconciler/revision/resources/queue.go +++ b/pkg/reconciler/revision/resources/queue.go @@ -40,6 +40,7 @@ import ( "knative.dev/serving/pkg/deployment" "knative.dev/serving/pkg/networking" "knative.dev/serving/pkg/queue" + "knative.dev/serving/pkg/queue/drain" "knative.dev/serving/pkg/queue/readiness" "knative.dev/serving/pkg/reconciler/revision/config" ) @@ -297,6 +298,13 @@ func makeQueueContainer(rev *v1.Revision, cfg *config.Config) (*corev1.Container }}, }, } + // Make queue proxy readiness probe more aggressive only if not user-defined + if queueProxyReadinessProbe.PeriodSeconds == 0 { + queueProxyReadinessProbe.PeriodSeconds = 1 + } + if queueProxyReadinessProbe.FailureThreshold == 0 { + queueProxyReadinessProbe.FailureThreshold = 1 + } } // Sidecar readiness probes @@ -356,6 +364,16 @@ func makeQueueContainer(rev *v1.Revision, cfg *config.Config) (*corev1.Container StartupProbe: nil, ReadinessProbe: queueProxyReadinessProbe, SecurityContext: queueSecurityContext, + Lifecycle: &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + Exec: &corev1.ExecAction{ + Command: []string{ + "/bin/sh", "-c", + drain.QueueProxyPreStopScript, + }, + }, + }, + }, Env: []corev1.EnvVar{{ Name: "SERVING_NAMESPACE", Value: rev.Namespace, @@ -439,6 +457,10 @@ func makeQueueContainer(rev *v1.Revision, cfg *config.Config) (*corev1.Container Name: "OBSERVABILITY_CONFIG", Value: string(o11yConfig), }}, + VolumeMounts: []corev1.VolumeMount{{ + Name: "knative-drain-signal", + MountPath: "/var/run/knative", + }}, } return c, nil @@ -470,6 +492,10 @@ func applyReadinessProbeDefaults(p *corev1.Probe, port int32) { p.GRPC.Port = port } + // Set aggressive defaults for faster failure 
detection + if p.FailureThreshold == 0 { + p.FailureThreshold = 1 // Mark unready immediately on failure + } if p.PeriodSeconds > 0 && p.TimeoutSeconds < 1 { p.TimeoutSeconds = 1 } diff --git a/pkg/reconciler/revision/resources/queue_test.go b/pkg/reconciler/revision/resources/queue_test.go index 1f7f05a4bc61..e30e350c896c 100644 --- a/pkg/reconciler/revision/resources/queue_test.go +++ b/pkg/reconciler/revision/resources/queue_test.go @@ -78,7 +78,7 @@ var ( defaults, _ = apicfg.NewDefaultsConfigFromMap(nil) ) -const testProbeJSONTemplate = `{"tcpSocket":{"port":%d,"host":"127.0.0.1"}}` +const testProbeJSONTemplate = `{"tcpSocket":{"port":%d,"host":"127.0.0.1"},"failureThreshold":1}` func TestMakeQueueContainer(t *testing.T) { tests := []struct { @@ -90,409 +90,411 @@ func TestMakeQueueContainer(t *testing.T) { dc deployment.Config fc apicfg.Features want corev1.Container - }{{ - name: "autoscaler single", - rev: revision("bar", "foo", - withContainers(containers), - withContainerConcurrency(1)), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "CONTAINER_CONCURRENCY": "1", - }) - }), - }, { - name: "custom readiness probe port", - rev: revision("bar", "foo", - withContainers([]corev1.Container{{ - Name: servingContainerName, - ReadinessProbe: &corev1.Probe{ - ProbeHandler: corev1.ProbeHandler{ - TCPSocket: &corev1.TCPSocketAction{ - Host: "127.0.0.1", - Port: intstr.FromInt(8087), + }{ + { + name: "autoscaler single", + rev: revision("bar", "foo", + withContainers(containers), + withContainerConcurrency(1)), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "CONTAINER_CONCURRENCY": "1", + }) + }), + }, { + name: "custom readiness probe port", + rev: revision("bar", "foo", + withContainers([]corev1.Container{{ + Name: servingContainerName, + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Host: "127.0.0.1", + Port: intstr.FromInt(8087), + }, }, }, + Ports: []corev1.ContainerPort{{ + ContainerPort: 1955, + Name: string(netapi.ProtocolH2C), + }}, + }})), + dc: deployment.Config{ + QueueSidecarImage: "alpine", + }, + want: queueContainer(func(c *corev1.Container) { + c.Image = "alpine" + c.Ports = append(queueNonServingPorts, queueHTTP2Port, queueHTTPSPort) + c.ReadinessProbe.HTTPGet.Port.IntVal = queueHTTP2Port.ContainerPort + c.Env = env(map[string]string{ + "USER_PORT": "1955", + "QUEUE_SERVING_PORT": "8013", + }) + }), + }, { + name: "custom sidecar image, container port, protocol", + rev: revision("bar", "foo", + withContainers([]corev1.Container{{ + Name: servingContainerName, + ReadinessProbe: testProbe, + Ports: []corev1.ContainerPort{{ + ContainerPort: 1955, + Name: string(netapi.ProtocolH2C), + }}, + }})), + dc: deployment.Config{ + QueueSidecarImage: "alpine", + }, + want: queueContainer(func(c *corev1.Container) { + c.Image = "alpine" + c.Ports = append(queueNonServingPorts, queueHTTP2Port, queueHTTPSPort) + c.ReadinessProbe.HTTPGet.Port.IntVal = queueHTTP2Port.ContainerPort + c.Env = env(map[string]string{ + "USER_PORT": "1955", + "QUEUE_SERVING_PORT": "8013", + }) + }), + }, { + name: "service name in labels", + rev: revision("bar", "foo", + withContainers(containers), + func(revision *v1.Revision) { + revision.Labels = map[string]string{ + serving.ServiceLabelKey: "svc", + } + }), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "SERVING_SERVICE": "svc", + }) + }), + }, { + name: "config owner as env var, zero 
concurrency", + rev: revision("blah", "baz", + withContainers(containers), + withContainerConcurrency(0), + func(revision *v1.Revision) { + revision.OwnerReferences = []metav1.OwnerReference{{ + APIVersion: v1.SchemeGroupVersion.String(), + Kind: "Configuration", + Name: "the-parent-config-name", + Controller: ptr.Bool(true), + BlockOwnerDeletion: ptr.Bool(true), + }} + }), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "CONTAINER_CONCURRENCY": "0", + "SERVING_CONFIGURATION": "the-parent-config-name", + "SERVING_NAMESPACE": "baz", + "SERVING_REVISION": "blah", + }) + }), + }, { + name: "logging configuration as env var", + rev: revision("this", "log", + withContainers(containers)), + lc: logging.Config{ + LoggingConfig: "The logging configuration goes here", + LoggingLevel: map[string]zapcore.Level{ + "queueproxy": zapcore.ErrorLevel, }, - Ports: []corev1.ContainerPort{{ - ContainerPort: 1955, - Name: string(netapi.ProtocolH2C), - }}, - }})), - dc: deployment.Config{ - QueueSidecarImage: "alpine", - }, - want: queueContainer(func(c *corev1.Container) { - c.Image = "alpine" - c.Ports = append(queueNonServingPorts, queueHTTP2Port, queueHTTPSPort) - c.ReadinessProbe.ProbeHandler.HTTPGet.Port.IntVal = queueHTTP2Port.ContainerPort - c.Env = env(map[string]string{ - "USER_PORT": "1955", - "QUEUE_SERVING_PORT": "8013", - }) - }), - }, { - name: "custom sidecar image, container port, protocol", - rev: revision("bar", "foo", - withContainers([]corev1.Container{{ - Name: servingContainerName, - ReadinessProbe: testProbe, - Ports: []corev1.ContainerPort{{ - ContainerPort: 1955, - Name: string(netapi.ProtocolH2C), - }}, - }})), - dc: deployment.Config{ - QueueSidecarImage: "alpine", - }, - want: queueContainer(func(c *corev1.Container) { - c.Image = "alpine" - c.Ports = append(queueNonServingPorts, queueHTTP2Port, queueHTTPSPort) - c.ReadinessProbe.ProbeHandler.HTTPGet.Port.IntVal = queueHTTP2Port.ContainerPort - c.Env = env(map[string]string{ - "USER_PORT": "1955", - "QUEUE_SERVING_PORT": "8013", - }) - }), - }, { - name: "service name in labels", - rev: revision("bar", "foo", - withContainers(containers), - func(revision *v1.Revision) { - revision.Labels = map[string]string{ - serving.ServiceLabelKey: "svc", - } + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "SERVING_LOGGING_CONFIG": "The logging configuration goes here", + "SERVING_LOGGING_LEVEL": "error", + "SERVING_NAMESPACE": "log", + "SERVING_REVISION": "this", + }) }), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "SERVING_SERVICE": "svc", - }) - }), - }, { - name: "config owner as env var, zero concurrency", - rev: revision("blah", "baz", - withContainers(containers), - withContainerConcurrency(0), - func(revision *v1.Revision) { - revision.ObjectMeta.OwnerReferences = []metav1.OwnerReference{{ - APIVersion: v1.SchemeGroupVersion.String(), - Kind: "Configuration", - Name: "the-parent-config-name", - Controller: ptr.Bool(true), - BlockOwnerDeletion: ptr.Bool(true), - }} + }, { + name: "container concurrency 10", + rev: revision("bar", "foo", + withContainers(containers), + withContainerConcurrency(10)), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "CONTAINER_CONCURRENCY": "10", + }) }), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "CONTAINER_CONCURRENCY": "0", - "SERVING_CONFIGURATION": "the-parent-config-name", - "SERVING_NAMESPACE": "baz", 
- "SERVING_REVISION": "blah", - }) - }), - }, { - name: "logging configuration as env var", - rev: revision("this", "log", - withContainers(containers)), - lc: logging.Config{ - LoggingConfig: "The logging configuration goes here", - LoggingLevel: map[string]zapcore.Level{ - "queueproxy": zapcore.ErrorLevel, + }, { + name: "request log configuration as env var", + rev: revision("bar", "foo", + withContainers(containers)), + oc: observability.Config{ + RequestLogTemplate: "test template", + EnableProbeRequestLog: true, }, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "SERVING_LOGGING_CONFIG": "The logging configuration goes here", - "SERVING_LOGGING_LEVEL": "error", - "SERVING_NAMESPACE": "log", - "SERVING_REVISION": "this", - }) - }), - }, { - name: "container concurrency 10", - rev: revision("bar", "foo", - withContainers(containers), - withContainerConcurrency(10)), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "CONTAINER_CONCURRENCY": "10", - }) - }), - }, { - name: "request log configuration as env var", - rev: revision("bar", "foo", - withContainers(containers)), - oc: observability.Config{ - RequestLogTemplate: "test template", - EnableProbeRequestLog: true, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{},"requestLogTemplate":"test template","enableProbeRequestLog":true}`, - }) - }), - }, { - name: "disabled request log configuration as env var", - rev: revision("bar", "foo", - withContainers(containers)), - oc: observability.Config{ - RequestLogTemplate: "test template", - EnableProbeRequestLog: false, - EnableRequestLog: false, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{},"requestLogTemplate":"test template"}`, - }) - }), - }, { - name: "request metrics backend as env var", - rev: revision("bar", "foo", - withContainers(containers)), - oc: observability.Config{ - RequestMetrics: observability.MetricsConfig{ - Protocol: "prometheus", + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{},"requestLogTemplate":"test template","enableProbeRequestLog":true}`, + }) + }), + }, { + name: "disabled request log configuration as env var", + rev: revision("bar", "foo", + withContainers(containers)), + oc: observability.Config{ + RequestLogTemplate: "test template", + EnableProbeRequestLog: false, + EnableRequestLog: false, }, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{"protocol":"prometheus"}}`, - }) - }), - }, { - name: "enable profiling", - rev: revision("bar", "foo", - withContainers(containers)), - oc: observability.Config{ - BaseConfig: observability.BaseConfig{ - Runtime: observability.RuntimeConfig{ - Profiling: "enabled", + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{},"requestLogTemplate":"test template"}`, + }) + }), + }, { + name: "request metrics backend as env var", + rev: revision("bar", "foo", + withContainers(containers)), + oc: observability.Config{ + RequestMetrics: 
observability.MetricsConfig{ + Protocol: "prometheus", }, }, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{"profiling":"enabled"},"requestMetrics":{}}`, - }) - c.Ports = append(queueNonServingPorts, profilingPort, queueHTTPPort, queueHTTPSPort) - }), - }, { - name: "custom TimeoutSeconds", - rev: revision("bar", "foo", - withContainers(containers), - func(revision *v1.Revision) { - revision.Spec.TimeoutSeconds = ptr.Int64(99) + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{"protocol":"prometheus"}}`, + }) + }), + }, { + name: "enable profiling", + rev: revision("bar", "foo", + withContainers(containers)), + oc: observability.Config{ + BaseConfig: observability.BaseConfig{ + Runtime: observability.RuntimeConfig{ + Profiling: "enabled", + }, + }, }, - ), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "REVISION_TIMEOUT_SECONDS": "99", - }) - }), - }, { - name: "custom ResponseStartTimeoutSeconds", - rev: revision("bar", "foo", - withContainers(containers), - func(revision *v1.Revision) { - revision.Spec.ResponseStartTimeoutSeconds = ptr.Int64(77) + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{"profiling":"enabled"},"requestMetrics":{}}`, + }) + c.Ports = append(queueNonServingPorts, profilingPort, queueHTTPPort, queueHTTPSPort) + }), + }, { + name: "custom TimeoutSeconds", + rev: revision("bar", "foo", + withContainers(containers), + func(revision *v1.Revision) { + revision.Spec.TimeoutSeconds = ptr.Int64(99) + }, + ), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "REVISION_TIMEOUT_SECONDS": "99", + }) + }), + }, { + name: "custom ResponseStartTimeoutSeconds", + rev: revision("bar", "foo", + withContainers(containers), + func(revision *v1.Revision) { + revision.Spec.ResponseStartTimeoutSeconds = ptr.Int64(77) + }, + ), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "REVISION_RESPONSE_START_TIMEOUT_SECONDS": "77", + }) + }), + }, { + name: "custom IdleTimeoutSeconds", + rev: revision("bar", "foo", + withContainers(containers), + func(revision *v1.Revision) { + revision.Spec.IdleTimeoutSeconds = ptr.Int64(99) + }, + ), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "REVISION_IDLE_TIMEOUT_SECONDS": "99", + }) + }), + }, { + name: "default resource config with feature qp defaults disabled", + rev: revision("bar", "foo", + withContainers(containers)), + dc: deployment.Config{ + QueueSidecarCPURequest: &deployment.QueueSidecarCPURequestDefault, }, - ), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "REVISION_RESPONSE_START_TIMEOUT_SECONDS": "77", - }) - }), - }, { - name: "custom IdleTimeoutSeconds", - rev: revision("bar", "foo", - withContainers(containers), - func(revision *v1.Revision) { - revision.Spec.IdleTimeoutSeconds = ptr.Int64(99) + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{}) + c.Resources.Requests = corev1.ResourceList{ + corev1.ResourceCPU: deployment.QueueSidecarCPURequestDefault, + } + }), + }, { + name: "resource config with feature qp defaults enabled", + rev: revision("bar", "foo", + withContainers(containers)), + dc: 
deployment.Config{ + QueueSidecarCPURequest: &deployment.QueueSidecarCPURequestDefault, }, - ), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "REVISION_IDLE_TIMEOUT_SECONDS": "99", - }) - }), - }, { - name: "default resource config with feature qp defaults disabled", - rev: revision("bar", "foo", - withContainers(containers)), - dc: deployment.Config{ - QueueSidecarCPURequest: &deployment.QueueSidecarCPURequestDefault, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{}) - c.Resources.Requests = corev1.ResourceList{ - corev1.ResourceCPU: deployment.QueueSidecarCPURequestDefault, - } - }), - }, { - name: "resource config with feature qp defaults enabled", - rev: revision("bar", "foo", - withContainers(containers)), - dc: deployment.Config{ - QueueSidecarCPURequest: &deployment.QueueSidecarCPURequestDefault, - }, - fc: apicfg.Features{ - QueueProxyResourceDefaults: apicfg.Enabled, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{}) - c.Resources.Requests = corev1.ResourceList{ - corev1.ResourceCPU: deployment.QueueSidecarCPURequestDefault, - corev1.ResourceMemory: deployment.QueueSidecarMemoryRequestDefault, - } - c.Resources.Limits = corev1.ResourceList{ - corev1.ResourceCPU: deployment.QueueSidecarCPULimitDefault, - corev1.ResourceMemory: deployment.QueueSidecarMemoryLimitDefault, - } - }), - }, { - name: "overridden resources", - rev: revision("bar", "foo", - withContainers(containers)), - dc: deployment.Config{ - QueueSidecarCPURequest: resourcePtr(resource.MustParse("123m")), - QueueSidecarEphemeralStorageRequest: resourcePtr(resource.MustParse("456M")), - QueueSidecarMemoryLimit: resourcePtr(resource.MustParse("789m")), - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{}) - c.Resources.Requests = corev1.ResourceList{ - corev1.ResourceCPU: resource.MustParse("123m"), - corev1.ResourceEphemeralStorage: resource.MustParse("456M"), - } - c.Resources.Limits = corev1.ResourceList{ - corev1.ResourceMemory: resource.MustParse("789m"), - } - }), - }, { - name: "collector address as env var", - rev: revision("bar", "foo", - withContainers(containers)), - oc: observability.Config{ - RequestMetrics: observability.MetricsConfig{ - Protocol: "http/protobuf", - Endpoint: "otel:55678", + fc: apicfg.Features{ + QueueProxyResourceDefaults: apicfg.Enabled, }, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{"protocol":"http/protobuf","endpoint":"otel:55678"}}`, - }) - }), - }, { - name: "HTTP2 autodetection enabled", - rev: revision("bar", "foo", - withContainers(containers)), - fc: apicfg.Features{ - AutoDetectHTTP2: apicfg.Enabled, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "ENABLE_HTTP2_AUTO_DETECTION": "true", - }) - }), - }, { - name: "HTTP1 full duplex enabled", - rev: revision("bar", "foo", - withContainers(containers), - WithRevisionAnnotations(map[string]string{apicfg.AllowHTTPFullDuplexFeatureKey: string(apicfg.Enabled)})), - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "ENABLE_HTTP_FULL_DUPLEX": "true", - }) - }), - }, { - name: "set root ca", - rev: revision("bar", "foo", - withContainers(containers)), - dc: deployment.Config{ - QueueSidecarRootCA: "xyz", - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = 
env(map[string]string{ - "ROOT_CA": "xyz", - }) - }), - }, { - name: "HTTP2 autodetection disabled", - rev: revision("bar", "foo", - withContainers(containers)), - fc: apicfg.Features{ - AutoDetectHTTP2: apicfg.Disabled, - }, - dc: deployment.Config{ - ProgressDeadline: 0 * time.Second, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "ENABLE_HTTP2_AUTO_DETECTION": "false", - }) - }), - }, { - name: "multi container probing enabled", - rev: revision("bar", "foo", withContainers(containers)), - fc: apicfg.Features{ - MultiContainerProbing: apicfg.Enabled, - }, - dc: deployment.Config{ - ProgressDeadline: 0 * time.Second, - }, - want: queueContainer(func(c *corev1.Container) { - c.Env = env(map[string]string{ - "ENABLE_MULTI_CONTAINER_PROBES": "true", - }) - }), - }, { - name: "multi container probing enabled with exec probes on all containers", - rev: revision("bar", "foo", withContainers([]corev1.Container{ - { - Name: servingContainerName, - ReadinessProbe: &corev1.Probe{ - ProbeHandler: corev1.ProbeHandler{ - Exec: &corev1.ExecAction{ - Command: []string{"bin/sh", "serving.sh"}, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{}) + c.Resources.Requests = corev1.ResourceList{ + corev1.ResourceCPU: deployment.QueueSidecarCPURequestDefault, + corev1.ResourceMemory: deployment.QueueSidecarMemoryRequestDefault, + } + c.Resources.Limits = corev1.ResourceList{ + corev1.ResourceCPU: deployment.QueueSidecarCPULimitDefault, + corev1.ResourceMemory: deployment.QueueSidecarMemoryLimitDefault, + } + }), + }, { + name: "overridden resources", + rev: revision("bar", "foo", + withContainers(containers)), + dc: deployment.Config{ + QueueSidecarCPURequest: resourcePtr(resource.MustParse("123m")), + QueueSidecarEphemeralStorageRequest: resourcePtr(resource.MustParse("456M")), + QueueSidecarMemoryLimit: resourcePtr(resource.MustParse("789m")), + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{}) + c.Resources.Requests = corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("123m"), + corev1.ResourceEphemeralStorage: resource.MustParse("456M"), + } + c.Resources.Limits = corev1.ResourceList{ + corev1.ResourceMemory: resource.MustParse("789m"), + } + }), + }, { + name: "collector address as env var", + rev: revision("bar", "foo", + withContainers(containers)), + oc: observability.Config{ + RequestMetrics: observability.MetricsConfig{ + Protocol: "http/protobuf", + Endpoint: "otel:55678", + }, + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "OBSERVABILITY_CONFIG": `{"tracing":{},"metrics":{},"runtime":{},"requestMetrics":{"protocol":"http/protobuf","endpoint":"otel:55678"}}`, + }) + }), + }, { + name: "HTTP2 autodetection enabled", + rev: revision("bar", "foo", + withContainers(containers)), + fc: apicfg.Features{ + AutoDetectHTTP2: apicfg.Enabled, + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "ENABLE_HTTP2_AUTO_DETECTION": "true", + }) + }), + }, { + name: "HTTP1 full duplex enabled", + rev: revision("bar", "foo", + withContainers(containers), + WithRevisionAnnotations(map[string]string{apicfg.AllowHTTPFullDuplexFeatureKey: string(apicfg.Enabled)})), + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "ENABLE_HTTP_FULL_DUPLEX": "true", + }) + }), + }, { + name: "set root ca", + rev: revision("bar", "foo", + withContainers(containers)), + dc: deployment.Config{ + 
QueueSidecarRootCA: "xyz", + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "ROOT_CA": "xyz", + }) + }), + }, { + name: "HTTP2 autodetection disabled", + rev: revision("bar", "foo", + withContainers(containers)), + fc: apicfg.Features{ + AutoDetectHTTP2: apicfg.Disabled, + }, + dc: deployment.Config{ + ProgressDeadline: 0 * time.Second, + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "ENABLE_HTTP2_AUTO_DETECTION": "false", + }) + }), + }, { + name: "multi container probing enabled", + rev: revision("bar", "foo", withContainers(containers)), + fc: apicfg.Features{ + MultiContainerProbing: apicfg.Enabled, + }, + dc: deployment.Config{ + ProgressDeadline: 0 * time.Second, + }, + want: queueContainer(func(c *corev1.Container) { + c.Env = env(map[string]string{ + "ENABLE_MULTI_CONTAINER_PROBES": "true", + }) + }), + }, { + name: "multi container probing enabled with exec probes on all containers", + rev: revision("bar", "foo", withContainers([]corev1.Container{ + { + Name: servingContainerName, + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + Exec: &corev1.ExecAction{ + Command: []string{"bin/sh", "serving.sh"}, + }, }, }, + Ports: []corev1.ContainerPort{{ + ContainerPort: 1955, + Name: string(netapi.ProtocolH2C), + }}, }, - Ports: []corev1.ContainerPort{{ - ContainerPort: 1955, - Name: string(netapi.ProtocolH2C), - }}, - }, - { - Name: sidecarContainerName, - ReadinessProbe: &corev1.Probe{ - ProbeHandler: corev1.ProbeHandler{ - Exec: &corev1.ExecAction{ - Command: []string{"bin/sh", "sidecar.sh"}, + { + Name: sidecarContainerName, + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + Exec: &corev1.ExecAction{ + Command: []string{"bin/sh", "sidecar.sh"}, + }, }, }, }, + })), + fc: apicfg.Features{ + MultiContainerProbing: apicfg.Enabled, }, - })), - fc: apicfg.Features{ - MultiContainerProbing: apicfg.Enabled, - }, - dc: deployment.Config{ - ProgressDeadline: 0 * time.Second, + dc: deployment.Config{ + ProgressDeadline: 0 * time.Second, + }, + want: queueContainer(func(c *corev1.Container) { + c.Ports = append(queueNonServingPorts, queueHTTP2Port, queueHTTPSPort) + c.ReadinessProbe.HTTPGet.Port.IntVal = queueHTTP2Port.ContainerPort + c.Env = env(map[string]string{ + "ENABLE_MULTI_CONTAINER_PROBES": "true", + "USER_PORT": "1955", + "QUEUE_SERVING_PORT": "8013", + }) + }), }, - want: queueContainer(func(c *corev1.Container) { - c.Ports = append(queueNonServingPorts, queueHTTP2Port, queueHTTPSPort) - c.ReadinessProbe.HTTPGet.Port.IntVal = queueHTTP2Port.ContainerPort - c.Env = env(map[string]string{ - "ENABLE_MULTI_CONTAINER_PROBES": "true", - "USER_PORT": "1955", - "QUEUE_SERVING_PORT": "8013", - }) - }), - }} + } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if len(test.rev.Spec.PodSpec.Containers) == 0 { + if len(test.rev.Spec.Containers) == 0 { test.rev.Spec.PodSpec = corev1.PodSpec{ Containers: []corev1.Container{{ Name: servingContainerName, @@ -544,7 +546,7 @@ func TestMakeQueueContainerWithPercentageAnnotation(t *testing.T) { revision.Annotations = map[string]string{ serving.QueueSidecarResourcePercentageAnnotationKey: "20", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, Resources: corev1.ResourceRequirements{ @@ -569,7 +571,7 @@ func TestMakeQueueContainerWithPercentageAnnotation(t *testing.T) { revision.Annotations = 
map[string]string{ serving.QueueSidecarResourcePercentageAnnotationKey: "0.2", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, Resources: corev1.ResourceRequirements{ @@ -594,7 +596,7 @@ func TestMakeQueueContainerWithPercentageAnnotation(t *testing.T) { revision.Annotations = map[string]string{ serving.QueueSidecarResourcePercentageAnnotationKey: "foo", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, Resources: corev1.ResourceRequirements{ @@ -621,7 +623,7 @@ func TestMakeQueueContainerWithPercentageAnnotation(t *testing.T) { revision.Annotations = map[string]string{ serving.QueueSidecarResourcePercentageAnnotationKey: "100", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, Resources: corev1.ResourceRequirements{ @@ -682,7 +684,7 @@ func TestMakeQueueContainerWithResourceAnnotations(t *testing.T) { serving.QueueSidecarEphemeralStorageResourceRequestAnnotationKey: "500Mi", serving.QueueSidecarEphemeralStorageResourceLimitAnnotationKey: "600Mi", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, }} @@ -710,7 +712,7 @@ func TestMakeQueueContainerWithResourceAnnotations(t *testing.T) { serving.QueueSidecarMemoryResourceRequestAnnotationKey: "Gdx", serving.QueueSidecarMemoryResourceLimitAnnotationKey: "2Gi", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, }} @@ -731,7 +733,7 @@ func TestMakeQueueContainerWithResourceAnnotations(t *testing.T) { serving.QueueSidecarMemoryResourceLimitAnnotationKey: "4Gi", serving.QueueSidecarResourcePercentageAnnotationKey: "50", } - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: testProbe, Resources: corev1.ResourceRequirements{ @@ -783,7 +785,7 @@ func TestMakeQueueContainerWithResourceAnnotations(t *testing.T) { func TestProbeGenerationHTTPDefaults(t *testing.T) { rev := revision("bar", "foo", func(revision *v1.Revision) { - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, ReadinessProbe: &corev1.Probe{ ProbeHandler: corev1.ProbeHandler{ @@ -806,8 +808,9 @@ func TestProbeGenerationHTTPDefaults(t *testing.T) { Scheme: corev1.URISchemeHTTP, }, }, - PeriodSeconds: 1, - TimeoutSeconds: 10, + PeriodSeconds: 1, + TimeoutSeconds: 10, + FailureThreshold: 1, } wantProbeJSON, err := json.Marshal(expectedProbe) @@ -829,8 +832,9 @@ func TestProbeGenerationHTTPDefaults(t *testing.T) { }}, }, }, - PeriodSeconds: 1, - TimeoutSeconds: 10, + PeriodSeconds: 1, + TimeoutSeconds: 10, + FailureThreshold: 1, } }) @@ -850,7 +854,7 @@ func TestProbeGenerationHTTP(t *testing.T) { rev := revision("bar", "foo", func(revision *v1.Revision) { - revision.Spec.PodSpec.Containers = []corev1.Container{{ + revision.Spec.Containers = []corev1.Container{{ Name: servingContainerName, Ports: []corev1.ContainerPort{{ ContainerPort: userPort, @@ -877,8 +881,9 @@ func TestProbeGenerationHTTP(t *testing.T) { Scheme: corev1.URISchemeHTTPS, }, }, - 
PeriodSeconds: 2, - TimeoutSeconds: 10, + PeriodSeconds: 2, + TimeoutSeconds: 10, + FailureThreshold: 1, } wantProbeJSON, err := json.Marshal(expectedProbe) @@ -901,8 +906,9 @@ func TestProbeGenerationHTTP(t *testing.T) { }}, }, }, - PeriodSeconds: 2, - TimeoutSeconds: 10, + PeriodSeconds: 2, + TimeoutSeconds: 10, + FailureThreshold: 1, } }) @@ -935,6 +941,7 @@ func TestTCPProbeGeneration(t *testing.T) { }, PeriodSeconds: 0, SuccessThreshold: 3, + FailureThreshold: 1, }, rev: v1.RevisionSpec{ TimeoutSeconds: ptr.Int64(45), @@ -968,9 +975,10 @@ func TestTCPProbeGeneration(t *testing.T) { }}, }, }, - PeriodSeconds: 0, + PeriodSeconds: 1, TimeoutSeconds: 0, SuccessThreshold: 3, + FailureThreshold: 1, } c.Env = env(map[string]string{"USER_PORT": strconv.Itoa(userPort)}) }), @@ -997,8 +1005,9 @@ func TestTCPProbeGeneration(t *testing.T) { Port: intstr.FromInt(int(v1.DefaultUserPort)), }, }, - PeriodSeconds: 1, - TimeoutSeconds: 1, + PeriodSeconds: 1, + TimeoutSeconds: 1, + FailureThreshold: 1, }, want: queueContainer(func(c *corev1.Container) { c.ReadinessProbe = &corev1.Probe{ @@ -1013,7 +1022,8 @@ func TestTCPProbeGeneration(t *testing.T) { }, PeriodSeconds: 1, // Inherit Kubernetes default here rather than overriding as we need to do for exec probe. - TimeoutSeconds: 0, + TimeoutSeconds: 0, + FailureThreshold: 1, } c.Env = env(map[string]string{}) }), diff --git a/pkg/webhook/podspec_dryrun.go b/pkg/webhook/podspec_dryrun.go index 852efc101b61..8346a692c317 100644 --- a/pkg/webhook/podspec_dryrun.go +++ b/pkg/webhook/podspec_dryrun.go @@ -61,6 +61,16 @@ func validatePodSpec(ctx context.Context, ps v1.RevisionSpec, namespace string) rev.SetDefaults(ctx) podSpec := resources.BuildPodSpec(rev, resources.BuildUserContainers(rev), nil /*configs*/) + // Add the drain volume that BuildUserContainers adds volume mounts for + // This is necessary because BuildUserContainers adds the volume mount but + // the volume itself is only added in makePodSpec when cfg is not nil + podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{ + Name: "knative-drain-signal", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }) + // Make a sample pod with the template Revisions & PodSpec and dryrun call to API-server pod := &corev1.Pod{ ObjectMeta: om, diff --git a/test/e2e/websocket_test.go b/test/e2e/websocket_test.go index ec7be61ac064..4ef55e52f446 100644 --- a/test/e2e/websocket_test.go +++ b/test/e2e/websocket_test.go @@ -322,6 +322,11 @@ func TestWebSocketWithTimeout(t *testing.T) { idleTimeoutSeconds: 10, delay: "20", expectError: true, + }, { + name: "websocket does not drop after queue drain is called at 30s", + timeoutSeconds: 60, + delay: "45", + expectError: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -349,6 +354,45 @@ func TestWebSocketWithTimeout(t *testing.T) { } } +func TestWebSocketDrain(t *testing.T) { + clients := Setup(t) + + testCases := []struct { + name string + timeoutSeconds int64 + delay string + expectError bool + }{{ + name: "websocket does not drop after queue drain is called", + timeoutSeconds: 60, + delay: "45", + expectError: false, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + names := test.ResourceNames{ + Service: test.ObjectNameForTest(t), + Image: wsServerTestImageName, + } + + // Clean up in both abnormal and normal exits. 
+ test.EnsureTearDown(t, clients, &names) + + _, err := v1test.CreateServiceReady(t, clients, &names, + rtesting.WithRevisionTimeoutSeconds(tc.timeoutSeconds), + ) + if err != nil { + t.Fatal("Failed to create WebSocket server:", err) + } + // Validate the websocket connection. + err = ValidateWebSocketConnection(t, clients, names, tc.delay) + if (err == nil && tc.expectError) || (err != nil && !tc.expectError) { + t.Error(err) + } + }) + } +} + func abs(a int) int { if a < 0 { return -a diff --git a/test/prober.go b/test/prober.go index d1d254fa9400..e6d527b2bd48 100644 --- a/test/prober.go +++ b/test/prober.go @@ -201,7 +201,7 @@ func (m *manager) SLI() (total, failures int64) { total += pt failures += pf } - return + return total, failures } // Foreach implements ProberManager diff --git a/vendor/google.golang.org/grpc/MAINTAINERS.md b/vendor/google.golang.org/grpc/MAINTAINERS.md index 5d4096d46a04..df35bb9a882a 100644 --- a/vendor/google.golang.org/grpc/MAINTAINERS.md +++ b/vendor/google.golang.org/grpc/MAINTAINERS.md @@ -9,21 +9,19 @@ for general contribution guidelines. ## Maintainers (in alphabetical order) -- [aranjans](https://github.com/aranjans), Google LLC - [arjan-bal](https://github.com/arjan-bal), Google LLC - [arvindbr8](https://github.com/arvindbr8), Google LLC - [atollena](https://github.com/atollena), Datadog, Inc. - [dfawley](https://github.com/dfawley), Google LLC - [easwars](https://github.com/easwars), Google LLC -- [erm-g](https://github.com/erm-g), Google LLC - [gtcooke94](https://github.com/gtcooke94), Google LLC -- [purnesh42h](https://github.com/purnesh42h), Google LLC -- [zasweq](https://github.com/zasweq), Google LLC ## Emeritus Maintainers (in alphabetical order) - [adelez](https://github.com/adelez) +- [aranjans](https://github.com/aranjans) - [canguler](https://github.com/canguler) - [cesarghali](https://github.com/cesarghali) +- [erm-g](https://github.com/erm-g) - [iamqizhao](https://github.com/iamqizhao) - [jeanbza](https://github.com/jeanbza) - [jtattermusch](https://github.com/jtattermusch) @@ -32,5 +30,7 @@ for general contribution guidelines. - [matt-kwong](https://github.com/matt-kwong) - [menghanl](https://github.com/menghanl) - [nicolasnoble](https://github.com/nicolasnoble) +- [purnesh42h](https://github.com/purnesh42h) - [srini100](https://github.com/srini100) - [yongni](https://github.com/yongni) +- [zasweq](https://github.com/zasweq) diff --git a/vendor/google.golang.org/grpc/balancer/endpointsharding/endpointsharding.go b/vendor/google.golang.org/grpc/balancer/endpointsharding/endpointsharding.go index 0ad6bb1f2203..360db08ebc13 100644 --- a/vendor/google.golang.org/grpc/balancer/endpointsharding/endpointsharding.go +++ b/vendor/google.golang.org/grpc/balancer/endpointsharding/endpointsharding.go @@ -37,6 +37,8 @@ import ( "google.golang.org/grpc/resolver" ) +var randIntN = rand.IntN + // ChildState is the balancer state of a child along with the endpoint which // identifies the child balancer. type ChildState struct { @@ -112,6 +114,21 @@ type endpointSharding struct { mu sync.Mutex } +// rotateEndpoints returns a slice of all the input endpoints rotated a random +// amount. +func rotateEndpoints(es []resolver.Endpoint) []resolver.Endpoint { + les := len(es) + if les == 0 { + return es + } + r := randIntN(les) + // Make a copy to avoid mutating data beyond the end of es. 
+ ret := make([]resolver.Endpoint, les) + copy(ret, es[r:]) + copy(ret[les-r:], es[:r]) + return ret +} + // UpdateClientConnState creates a child for new endpoints and deletes children // for endpoints that are no longer present. It also updates all the children, // and sends a single synchronous update of the childrens' aggregated state at @@ -133,7 +150,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState newChildren := resolver.NewEndpointMap[*balancerWrapper]() // Update/Create new children. - for _, endpoint := range state.ResolverState.Endpoints { + for _, endpoint := range rotateEndpoints(state.ResolverState.Endpoints) { if _, ok := newChildren.Get(endpoint); ok { // Endpoint child was already created, continue to avoid duplicate // update. @@ -279,7 +296,7 @@ func (es *endpointSharding) updateState() { p := &pickerWithChildStates{ pickers: pickers, childStates: childStates, - next: uint32(rand.IntN(len(pickers))), + next: uint32(randIntN(len(pickers))), } es.cc.UpdateState(balancer.State{ ConnectivityState: aggState, diff --git a/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go index e62047256afb..67f315a0dbc4 100644 --- a/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go +++ b/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go @@ -67,21 +67,21 @@ var ( disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ Name: "grpc.lb.pick_first.disconnections", Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.", - Unit: "disconnection", + Unit: "{disconnection}", Labels: []string{"grpc.target"}, Default: false, }) connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ Name: "grpc.lb.pick_first.connection_attempts_succeeded", Description: "EXPERIMENTAL. Number of successful connection attempts.", - Unit: "attempt", + Unit: "{attempt}", Labels: []string{"grpc.target"}, Default: false, }) connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ Name: "grpc.lb.pick_first.connection_attempts_failed", Description: "EXPERIMENTAL. 
Number of failed connection attempts.", - Unit: "attempt", + Unit: "{attempt}", Labels: []string{"grpc.target"}, Default: false, }) diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go index cd3eaf8ddcbd..3f762285db71 100644 --- a/vendor/google.golang.org/grpc/clientconn.go +++ b/vendor/google.golang.org/grpc/clientconn.go @@ -208,7 +208,7 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority) cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz) - cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers) + cc.pickerWrapper = newPickerWrapper() cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers) @@ -1076,13 +1076,6 @@ func (cc *ClientConn) healthCheckConfig() *healthCheckConfig { return cc.sc.healthCheckConfig } -func (cc *ClientConn) getTransport(ctx context.Context, failfast bool, method string) (transport.ClientTransport, balancer.PickResult, error) { - return cc.pickerWrapper.pick(ctx, failfast, balancer.PickInfo{ - Ctx: ctx, - FullMethodName: method, - }) -} - func (cc *ClientConn) applyServiceConfigAndBalancer(sc *ServiceConfig, configSelector iresolver.ConfigSelector) { if sc == nil { // should never reach here. @@ -1831,7 +1824,7 @@ func (cc *ClientConn) initAuthority() error { } else if auth, ok := cc.resolverBuilder.(resolver.AuthorityOverrider); ok { cc.authority = auth.OverrideAuthority(cc.parsedTarget) } else if strings.HasPrefix(endpoint, ":") { - cc.authority = "localhost" + endpoint + cc.authority = "localhost" + encodeAuthority(endpoint) } else { cc.authority = encodeAuthority(endpoint) } diff --git a/vendor/google.golang.org/grpc/credentials/credentials.go b/vendor/google.golang.org/grpc/credentials/credentials.go index a63ab606e665..c8e337cdda07 100644 --- a/vendor/google.golang.org/grpc/credentials/credentials.go +++ b/vendor/google.golang.org/grpc/credentials/credentials.go @@ -96,10 +96,11 @@ func (c CommonAuthInfo) GetCommonAuthInfo() CommonAuthInfo { return c } -// ProtocolInfo provides information regarding the gRPC wire protocol version, -// security protocol, security protocol version in use, server name, etc. +// ProtocolInfo provides static information regarding transport credentials. type ProtocolInfo struct { // ProtocolVersion is the gRPC wire protocol version. + // + // Deprecated: this is unused by gRPC. ProtocolVersion string // SecurityProtocol is the security protocol in use. SecurityProtocol string @@ -109,7 +110,16 @@ type ProtocolInfo struct { // // Deprecated: please use Peer.AuthInfo. SecurityVersion string - // ServerName is the user-configured server name. + // ServerName is the user-configured server name. If set, this overrides + // the default :authority header used for all RPCs on the channel using the + // containing credentials, unless grpc.WithAuthority is set on the channel, + // in which case that setting will take precedence. + // + // This must be a valid `:authority` header according to + // [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2). + // + // Deprecated: Users should use grpc.WithAuthority to override the authority + // on a channel instead of configuring the credentials. ServerName string } @@ -173,12 +183,17 @@ type TransportCredentials interface { // Clone makes a copy of this TransportCredentials. 
Clone() TransportCredentials // OverrideServerName specifies the value used for the following: + // // - verifying the hostname on the returned certificates // - as SNI in the client's handshake to support virtual hosting // - as the value for `:authority` header at stream creation time // - // Deprecated: use grpc.WithAuthority instead. Will be supported - // throughout 1.x. + // The provided string should be a valid `:authority` header according to + // [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2). + // + // Deprecated: this method is unused by gRPC. Users should use + // grpc.WithAuthority to override the authority on a channel instead of + // configuring the credentials. OverrideServerName(string) error } diff --git a/vendor/google.golang.org/grpc/credentials/tls.go b/vendor/google.golang.org/grpc/credentials/tls.go index 20f65f7bd956..8277be7d6f85 100644 --- a/vendor/google.golang.org/grpc/credentials/tls.go +++ b/vendor/google.golang.org/grpc/credentials/tls.go @@ -110,14 +110,14 @@ func (c tlsCreds) Info() ProtocolInfo { func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) { // use local cfg to avoid clobbering ServerName if using multiple endpoints cfg := credinternal.CloneTLSConfig(c.config) - if cfg.ServerName == "" { - serverName, _, err := net.SplitHostPort(authority) - if err != nil { - // If the authority had no host port or if the authority cannot be parsed, use it as-is. - serverName = authority - } - cfg.ServerName = serverName + + serverName, _, err := net.SplitHostPort(authority) + if err != nil { + // If the authority had no host port or if the authority cannot be parsed, use it as-is. + serverName = authority } + cfg.ServerName = serverName + conn := tls.Client(rawConn, cfg) errChannel := make(chan error, 1) go func() { @@ -259,9 +259,11 @@ func applyDefaults(c *tls.Config) *tls.Config { // certificates to establish the identity of the client need to be included in // the credentials (eg: for mTLS), use NewTLS instead, where a complete // tls.Config can be specified. -// serverNameOverride is for testing only. If set to a non empty string, -// it will override the virtual host name of authority (e.g. :authority header -// field) in requests. +// +// serverNameOverride is for testing only. If set to a non empty string, it will +// override the virtual host name of authority (e.g. :authority header field) in +// requests. Users should use grpc.WithAuthority passed to grpc.NewClient to +// override the authority of the client instead. func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) TransportCredentials { return NewTLS(&tls.Config{ServerName: serverNameOverride, RootCAs: cp}) } @@ -271,9 +273,11 @@ func NewClientTLSFromCert(cp *x509.CertPool, serverNameOverride string) Transpor // certificates to establish the identity of the client need to be included in // the credentials (eg: for mTLS), use NewTLS instead, where a complete // tls.Config can be specified. -// serverNameOverride is for testing only. If set to a non empty string, -// it will override the virtual host name of authority (e.g. :authority header -// field) in requests. +// +// serverNameOverride is for testing only. If set to a non empty string, it will +// override the virtual host name of authority (e.g. :authority header field) in +// requests. Users should use grpc.WithAuthority passed to grpc.NewClient to +// override the authority of the client instead. 
func NewClientTLSFromFile(certFile, serverNameOverride string) (TransportCredentials, error) { b, err := os.ReadFile(certFile) if err != nil { diff --git a/vendor/google.golang.org/grpc/dialoptions.go b/vendor/google.golang.org/grpc/dialoptions.go index ec0ca89ccdca..7a5ac2e7c494 100644 --- a/vendor/google.golang.org/grpc/dialoptions.go +++ b/vendor/google.golang.org/grpc/dialoptions.go @@ -608,6 +608,8 @@ func WithChainStreamInterceptor(interceptors ...StreamClientInterceptor) DialOpt // WithAuthority returns a DialOption that specifies the value to be used as the // :authority pseudo-header and as the server name in authentication handshake. +// This overrides all other ways of setting authority on the channel, but can be +// overridden per-call by using grpc.CallAuthority. func WithAuthority(a string) DialOption { return newFuncDialOption(func(o *dialOptions) { o.authority = a diff --git a/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go b/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go index 2fdaed88dbd1..7e060f5ed132 100644 --- a/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go +++ b/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go @@ -26,26 +26,32 @@ import ( ) var ( - // TXTErrIgnore is set if TXT errors should be ignored ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false"). + // EnableTXTServiceConfig is set if the DNS resolver should perform TXT + // lookups for service config ("GRPC_ENABLE_TXT_SERVICE_CONFIG" is not + // "false"). + EnableTXTServiceConfig = boolFromEnv("GRPC_ENABLE_TXT_SERVICE_CONFIG", true) + + // TXTErrIgnore is set if TXT errors should be ignored + // ("GRPC_GO_IGNORE_TXT_ERRORS" is not "false"). TXTErrIgnore = boolFromEnv("GRPC_GO_IGNORE_TXT_ERRORS", true) + // RingHashCap indicates the maximum ring size which defaults to 4096 // entries but may be overridden by setting the environment variable // "GRPC_RING_HASH_CAP". This does not override the default bounds // checking which NACKs configs specifying ring sizes > 8*1024*1024 (~8M). RingHashCap = uint64FromEnv("GRPC_RING_HASH_CAP", 4096, 1, 8*1024*1024) + // ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS // handshakes that can be performed. ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100) + // EnforceALPNEnabled is set if TLS connections to servers with ALPN disabled // should be rejected. The HTTP/2 protocol requires ALPN to be enabled, this // option is present for backward compatibility. This option may be overridden // by setting the environment variable "GRPC_ENFORCE_ALPN_ENABLED" to "true" // or "false". EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true) - // XDSFallbackSupport is the env variable that controls whether support for - // xDS fallback is turned on. If this is unset or is false, only the first - // xDS server in the list of server configs will be used. - XDSFallbackSupport = boolFromEnv("GRPC_EXPERIMENTAL_XDS_FALLBACK", true) + // NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used // instead of the exiting pickfirst implementation. 
This can be disabled by // setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST" diff --git a/vendor/google.golang.org/grpc/internal/internal.go b/vendor/google.golang.org/grpc/internal/internal.go index 3ac798e8e60d..2699223a27f1 100644 --- a/vendor/google.golang.org/grpc/internal/internal.go +++ b/vendor/google.golang.org/grpc/internal/internal.go @@ -182,35 +182,6 @@ var ( // other features, including the CSDS service. NewXDSResolverWithClientForTesting any // func(xdsclient.XDSClient) (resolver.Builder, error) - // RegisterRLSClusterSpecifierPluginForTesting registers the RLS Cluster - // Specifier Plugin for testing purposes, regardless of the XDSRLS environment - // variable. - // - // TODO: Remove this function once the RLS env var is removed. - RegisterRLSClusterSpecifierPluginForTesting func() - - // UnregisterRLSClusterSpecifierPluginForTesting unregisters the RLS Cluster - // Specifier Plugin for testing purposes. This is needed because there is no way - // to unregister the RLS Cluster Specifier Plugin after registering it solely - // for testing purposes using RegisterRLSClusterSpecifierPluginForTesting(). - // - // TODO: Remove this function once the RLS env var is removed. - UnregisterRLSClusterSpecifierPluginForTesting func() - - // RegisterRBACHTTPFilterForTesting registers the RBAC HTTP Filter for testing - // purposes, regardless of the RBAC environment variable. - // - // TODO: Remove this function once the RBAC env var is removed. - RegisterRBACHTTPFilterForTesting func() - - // UnregisterRBACHTTPFilterForTesting unregisters the RBAC HTTP Filter for - // testing purposes. This is needed because there is no way to unregister the - // HTTP Filter after registering it solely for testing purposes using - // RegisterRBACHTTPFilterForTesting(). - // - // TODO: Remove this function once the RBAC env var is removed. - UnregisterRBACHTTPFilterForTesting func() - // ORCAAllowAnyMinReportingInterval is for examples/orca use ONLY. ORCAAllowAnyMinReportingInterval any // func(so *orca.ServiceOptions) diff --git a/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go b/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go index ba5c5a95d0d7..ada5251cff3e 100644 --- a/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go +++ b/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go @@ -132,13 +132,13 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts // DNS address (non-IP). ctx, cancel := context.WithCancel(context.Background()) d := &dnsResolver{ - host: host, - port: port, - ctx: ctx, - cancel: cancel, - cc: cc, - rn: make(chan struct{}, 1), - disableServiceConfig: opts.DisableServiceConfig, + host: host, + port: port, + ctx: ctx, + cancel: cancel, + cc: cc, + rn: make(chan struct{}, 1), + enableServiceConfig: envconfig.EnableTXTServiceConfig && !opts.DisableServiceConfig, } d.resolver, err = internal.NewNetResolver(target.URL.Host) @@ -181,8 +181,8 @@ type dnsResolver struct { // finishes, race detector sometimes will warn lookup (READ the lookup // function pointers) inside watcher() goroutine has data race with // replaceNetFunc (WRITE the lookup function pointers). 
- wg sync.WaitGroup - disableServiceConfig bool + wg sync.WaitGroup + enableServiceConfig bool } // ResolveNow invoke an immediate resolution of the target that this @@ -346,7 +346,7 @@ func (d *dnsResolver) lookup() (*resolver.State, error) { if len(srv) > 0 { state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv}) } - if !d.disableServiceConfig { + if d.enableServiceConfig { state.ServiceConfig = d.lookupTXT(ctx) } return &state, nil diff --git a/vendor/google.golang.org/grpc/picker_wrapper.go b/vendor/google.golang.org/grpc/picker_wrapper.go index a2d2a798d488..aa52bfe95fd8 100644 --- a/vendor/google.golang.org/grpc/picker_wrapper.go +++ b/vendor/google.golang.org/grpc/picker_wrapper.go @@ -29,7 +29,6 @@ import ( "google.golang.org/grpc/internal/channelz" istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/transport" - "google.golang.org/grpc/stats" "google.golang.org/grpc/status" ) @@ -48,14 +47,11 @@ type pickerGeneration struct { // actions and unblock when there's a picker update. type pickerWrapper struct { // If pickerGen holds a nil pointer, the pickerWrapper is closed. - pickerGen atomic.Pointer[pickerGeneration] - statsHandlers []stats.Handler // to record blocking picker calls + pickerGen atomic.Pointer[pickerGeneration] } -func newPickerWrapper(statsHandlers []stats.Handler) *pickerWrapper { - pw := &pickerWrapper{ - statsHandlers: statsHandlers, - } +func newPickerWrapper() *pickerWrapper { + pw := &pickerWrapper{} pw.pickerGen.Store(&pickerGeneration{ blockingCh: make(chan struct{}), }) @@ -93,6 +89,12 @@ func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) { } } +type pick struct { + transport transport.ClientTransport // the selected transport + result balancer.PickResult // the contents of the pick from the LB policy + blocked bool // set if a picker call queued for a new picker +} + // pick returns the transport that will be used for the RPC. // It may block in the following cases: // - there's no picker @@ -100,15 +102,16 @@ func doneChannelzWrapper(acbw *acBalancerWrapper, result *balancer.PickResult) { // - the current picker returns other errors and failfast is false. // - the subConn returned by the current picker is not READY // When one of these situations happens, pick blocks until the picker gets updated. -func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (transport.ClientTransport, balancer.PickResult, error) { +func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.PickInfo) (pick, error) { var ch chan struct{} var lastPickErr error + pickBlocked := false for { pg := pw.pickerGen.Load() if pg == nil { - return nil, balancer.PickResult{}, ErrClientConnClosing + return pick{}, ErrClientConnClosing } if pg.picker == nil { ch = pg.blockingCh @@ -127,9 +130,9 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. } switch ctx.Err() { case context.DeadlineExceeded: - return nil, balancer.PickResult{}, status.Error(codes.DeadlineExceeded, errStr) + return pick{}, status.Error(codes.DeadlineExceeded, errStr) case context.Canceled: - return nil, balancer.PickResult{}, status.Error(codes.Canceled, errStr) + return pick{}, status.Error(codes.Canceled, errStr) } case <-ch: } @@ -145,9 +148,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. // In the second case, the only way it will get to this conditional is // if there is a new picker. 
if ch != nil { - for _, sh := range pw.statsHandlers { - sh.HandleRPC(ctx, &stats.PickerUpdated{}) - } + pickBlocked = true } ch = pg.blockingCh @@ -164,7 +165,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. if istatus.IsRestrictedControlPlaneCode(st) { err = status.Errorf(codes.Internal, "received picker error with illegal status: %v", err) } - return nil, balancer.PickResult{}, dropError{error: err} + return pick{}, dropError{error: err} } // For all other errors, wait for ready RPCs should block and other // RPCs should fail with unavailable. @@ -172,7 +173,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. lastPickErr = err continue } - return nil, balancer.PickResult{}, status.Error(codes.Unavailable, err.Error()) + return pick{}, status.Error(codes.Unavailable, err.Error()) } acbw, ok := pickResult.SubConn.(*acBalancerWrapper) @@ -183,9 +184,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer. if t := acbw.ac.getReadyTransport(); t != nil { if channelz.IsOn() { doneChannelzWrapper(acbw, &pickResult) - return t, pickResult, nil } - return t, pickResult, nil + return pick{transport: t, result: pickResult, blocked: pickBlocked}, nil } if pickResult.Done != nil { // Calling done with nil error, no bytes sent and no bytes received. diff --git a/vendor/google.golang.org/grpc/resolver/resolver.go b/vendor/google.golang.org/grpc/resolver/resolver.go index b84ef26d46d1..8e6af9514b6d 100644 --- a/vendor/google.golang.org/grpc/resolver/resolver.go +++ b/vendor/google.golang.org/grpc/resolver/resolver.go @@ -332,6 +332,11 @@ type AuthorityOverrider interface { // OverrideAuthority returns the authority to use for a ClientConn with the // given target. The implementation must generate it without blocking, // typically in line, and must keep it unchanged. + // + // The returned string must be a valid ":authority" header value, i.e. be + // encoded according to + // [RFC3986](https://datatracker.ietf.org/doc/html/rfc3986#section-3.2) as + // necessary. OverrideAuthority(Target) string } diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 70fe23f55022..1da2a542acde 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -1598,6 +1598,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv s: stream, p: &parser{r: stream, bufferPool: s.opts.bufferPool}, codec: s.getCodec(stream.ContentSubtype()), + desc: sd, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, diff --git a/vendor/google.golang.org/grpc/stats/stats.go b/vendor/google.golang.org/grpc/stats/stats.go index baf7740efba9..10bf998aa5be 100644 --- a/vendor/google.golang.org/grpc/stats/stats.go +++ b/vendor/google.golang.org/grpc/stats/stats.go @@ -64,15 +64,21 @@ func (s *Begin) IsClient() bool { return s.Client } func (s *Begin) isRPCStats() {} -// PickerUpdated indicates that the LB policy provided a new picker while the -// RPC was waiting for one. -type PickerUpdated struct{} +// DelayedPickComplete indicates that the RPC is unblocked following a delay in +// selecting a connection for the call. +type DelayedPickComplete struct{} -// IsClient indicates if the stats information is from client side. Only Client -// Side interfaces with a Picker, thus always returns true. 
-func (*PickerUpdated) IsClient() bool { return true } +// IsClient indicates DelayedPickComplete is available on the client. +func (*DelayedPickComplete) IsClient() bool { return true } -func (*PickerUpdated) isRPCStats() {} +func (*DelayedPickComplete) isRPCStats() {} + +// PickerUpdated indicates that the RPC is unblocked following a delay in +// selecting a connection for the call. +// +// Deprecated: will be removed in a future release; use DelayedPickComplete +// instead. +type PickerUpdated = DelayedPickComplete // InPayload contains stats about an incoming payload. type InPayload struct { diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index ca6948926f93..d9bbd4c57cf6 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -469,8 +469,9 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) func (a *csAttempt) getTransport() error { cs := a.cs - var err error - a.transport, a.pickResult, err = cs.cc.getTransport(a.ctx, cs.callInfo.failFast, cs.callHdr.Method) + pickInfo := balancer.PickInfo{Ctx: a.ctx, FullMethodName: cs.callHdr.Method} + pick, err := cs.cc.pickerWrapper.pick(a.ctx, cs.callInfo.failFast, pickInfo) + a.transport, a.pickResult = pick.transport, pick.result if err != nil { if de, ok := err.(dropError); ok { err = de.error @@ -481,6 +482,11 @@ func (a *csAttempt) getTransport() error { if a.trInfo != nil { a.trInfo.firstLine.SetRemoteAddr(a.transport.RemoteAddr()) } + if pick.blocked { + for _, sh := range a.statsHandlers { + sh.HandleRPC(a.ctx, &stats.DelayedPickComplete{}) + } + } return nil } @@ -1580,6 +1586,7 @@ type serverStream struct { s *transport.ServerStream p *parser codec baseCodec + desc *StreamDesc compressorV0 Compressor compressorV1 encoding.Compressor @@ -1588,6 +1595,8 @@ type serverStream struct { sendCompressorName string + recvFirstMsg bool // set after the first message is received + maxReceiveMessageSize int maxSendMessageSize int trInfo *traceInfo @@ -1774,6 +1783,10 @@ func (ss *serverStream) RecvMsg(m any) (err error) { binlog.Log(ss.ctx, chc) } } + // Received no request msg for non-client streaming rpcs. + if !ss.desc.ClientStreams && !ss.recvFirstMsg { + return status.Error(codes.Internal, "cardinality violation: received no request message from non-client-streaming RPC") + } return err } if err == io.ErrUnexpectedEOF { @@ -1781,6 +1794,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } return toRPCErr(err) } + ss.recvFirstMsg = true if len(ss.statsHandler) != 0 { for _, sh := range ss.statsHandler { sh.HandleRPC(ss.s.Context(), &stats.InPayload{ @@ -1800,7 +1814,19 @@ func (ss *serverStream) RecvMsg(m any) (err error) { binlog.Log(ss.ctx, cm) } } - return nil + + if ss.desc.ClientStreams { + // Subsequent messages should be received by subsequent RecvMsg calls. + return nil + } + // Special handling for non-client-stream rpcs. + // This recv expects EOF or errors, so we don't collect inPayload. + if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { + return nil + } else if err != nil { + return err + } + return status.Error(codes.Internal, "cardinality violation: received multiple request messages for non-client-streaming RPC") } // MethodFromServerStream returns the method string for the input stream. 
diff --git a/vendor/google.golang.org/grpc/version.go b/vendor/google.golang.org/grpc/version.go index 8b0e5f973d6d..bc1eb290f690 100644 --- a/vendor/google.golang.org/grpc/version.go +++ b/vendor/google.golang.org/grpc/version.go @@ -19,4 +19,4 @@ package grpc // Version is the current grpc version. -const Version = "1.74.2" +const Version = "1.75.0" diff --git a/vendor/modules.txt b/vendor/modules.txt index 21684e58a8ac..19b44aab2cc6 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -766,15 +766,15 @@ google.golang.org/api/option google.golang.org/api/option/internaloption google.golang.org/api/transport/http google.golang.org/api/transport/http/internal/propagation -# google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 +# google.golang.org/genproto/googleapis/api v0.0.0-20250707201910-8d1bb00bc6a7 ## explicit; go 1.23.0 google.golang.org/genproto/googleapis/api/httpbody -# google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 +# google.golang.org/genproto/googleapis/rpc v0.0.0-20250818200422-3122310a409c ## explicit; go 1.23.0 google.golang.org/genproto/googleapis/rpc/code google.golang.org/genproto/googleapis/rpc/errdetails google.golang.org/genproto/googleapis/rpc/status -# google.golang.org/grpc v1.74.2 +# google.golang.org/grpc v1.75.0 ## explicit; go 1.23.0 google.golang.org/grpc google.golang.org/grpc/attributes