From 63af521d153c7f64235d386166dd8e1e0d001ad8 Mon Sep 17 00:00:00 2001 From: Jaydip Gabani Date: Wed, 27 Mar 2024 23:20:36 +0000 Subject: [PATCH 1/4] refactoring to remove pubsub flags to improve experience Signed-off-by: Jaydip Gabani --- Makefile | 2 -- cmd/build/helmify/replacements.go | 2 -- cmd/build/helmify/static/values.yaml | 2 -- .../gatekeeper-audit-deployment.yaml | 2 -- .../charts/gatekeeper/values.yaml | 2 -- pkg/audit/manager.go | 4 +--- pkg/audit/stats_reporter.go | 1 - .../constrainttemplate/stats_reporter.go | 1 - pkg/controller/expansion/stats_reporter.go | 1 - pkg/pubsub/connection/connection.go | 2 +- pkg/pubsub/dapr/dapr.go | 20 ++++++++++++++-- pkg/pubsub/dapr/dapr_test.go | 24 ++++--------------- pkg/pubsub/dapr/fake_dapr_client.go | 11 ++++++++- pkg/pubsub/system.go | 18 ++++++++------ pkg/pubsub/system_test.go | 15 ++++-------- test/pubsub/publish-components.yaml | 3 ++- test/testutils/controller.go | 1 - website/docs/pubsub.md | 17 ++++++------- 18 files changed, 62 insertions(+), 66 deletions(-) diff --git a/Makefile b/Makefile index faccede490f..43cd57e62bf 100644 --- a/Makefile +++ b/Makefile @@ -224,8 +224,6 @@ ifeq ($(ENABLE_PUBSUB),true) --set disabledBuiltins={http.send} \ --set logMutations=true \ --set audit.enablePubsub=${ENABLE_PUBSUB} \ - --set audit.connection=${AUDIT_CONNECTION} \ - --set audit.channel=${AUDIT_CHANNEL} \ --set-string auditPodAnnotations.dapr\\.io/enabled=true \ --set-string auditPodAnnotations.dapr\\.io/app-id=audit \ --set-string auditPodAnnotations.dapr\\.io/metrics-port=9999 \ diff --git a/cmd/build/helmify/replacements.go b/cmd/build/helmify/replacements.go index a47238dbd64..d6d47fc3473 100644 --- a/cmd/build/helmify/replacements.go +++ b/cmd/build/helmify/replacements.go @@ -95,8 +95,6 @@ var replacements = map[string]string{ "- HELMSUBST_PUBSUB_ARGS": `{{ if .Values.audit.enablePubsub}} - --enable-pub-sub={{ .Values.audit.enablePubsub }} - - --audit-connection={{ .Values.audit.connection }} - - --audit-channel={{ .Values.audit.channel }} {{- end }}`, "HELMSUBST_MUTATING_WEBHOOK_FAILURE_POLICY": `{{ .Values.mutatingWebhookFailurePolicy }}`, diff --git a/cmd/build/helmify/static/values.yaml b/cmd/build/helmify/static/values.yaml index b5c09987127..475bb2f4e5d 100644 --- a/cmd/build/helmify/static/values.yaml +++ b/cmd/build/helmify/static/values.yaml @@ -212,8 +212,6 @@ controllerManager: # - ipBlock: # cidr: 0.0.0.0/0 audit: - enablePubsub: false - connection: audit-connection channel: audit-channel hostNetwork: false dnsPolicy: ClusterFirst diff --git a/manifest_staging/charts/gatekeeper/templates/gatekeeper-audit-deployment.yaml b/manifest_staging/charts/gatekeeper/templates/gatekeeper-audit-deployment.yaml index 77da1a4531a..0cbfbbc8b7a 100644 --- a/manifest_staging/charts/gatekeeper/templates/gatekeeper-audit-deployment.yaml +++ b/manifest_staging/charts/gatekeeper/templates/gatekeeper-audit-deployment.yaml @@ -68,8 +68,6 @@ spec: - --operation=status {{ if .Values.audit.enablePubsub}} - --enable-pub-sub={{ .Values.audit.enablePubsub }} - - --audit-connection={{ .Values.audit.connection }} - - --audit-channel={{ .Values.audit.channel }} {{- end }} {{ if not .Values.disableMutation}}- --operation=mutation-status{{- end }} - --logtostderr diff --git a/manifest_staging/charts/gatekeeper/values.yaml b/manifest_staging/charts/gatekeeper/values.yaml index b5c09987127..475bb2f4e5d 100644 --- a/manifest_staging/charts/gatekeeper/values.yaml +++ b/manifest_staging/charts/gatekeeper/values.yaml @@ -212,8 +212,6 @@ controllerManager: # - ipBlock: # cidr: 0.0.0.0/0 audit: - enablePubsub: false - connection: audit-connection channel: audit-channel hostNetwork: false dnsPolicy: ClusterFirst diff --git a/pkg/audit/manager.go b/pkg/audit/manager.go index 8859c429149..6b85793bdc7 100644 --- a/pkg/audit/manager.go +++ b/pkg/audit/manager.go @@ -67,8 +67,6 @@ var ( auditEventsInvolvedNamespace = flag.Bool("audit-events-involved-namespace", false, "emit audit events for each violation in the involved objects namespace, the default (false) generates events in the namespace Gatekeeper is installed in. Audit events from cluster-scoped resources will still follow the default behavior") auditMatchKindOnly = flag.Bool("audit-match-kind-only", false, "only use kinds specified in all constraints for auditing cluster resources. if kind is not specified in any of the constraints, it will audit all resources (same as setting this flag to false)") apiCacheDir = flag.String("api-cache-dir", defaultAPICacheDir, "The directory where audit from api server cache are stored, defaults to /tmp/audit") - auditConnection = flag.String("audit-connection", defaultConnection, "Connection name for publishing audit violation messages. Defaults to audit-connection") - auditChannel = flag.String("audit-channel", defaultChannel, "Channel name for publishing audit violation messages. Defaults to audit-channel") emptyAuditResults = newLimitQueue(0) logStatsAudit = flag.Bool("log-stats-audit", false, "(alpha) log stats metrics for the audit run") ) @@ -901,7 +899,7 @@ func (am *Manager) addAuditResponsesToUpdateLists( labels := r.obj.GetLabels() logViolation(am.log, constraint, ea, gvk, namespace, name, msg, details, labels) if *pubsubController.PubsubEnabled { - err := am.pubsubSystem.Publish(context.Background(), *auditConnection, *auditChannel, violationMsg(constraint, ea, gvk, namespace, name, msg, details, labels, timestamp)) + err := am.pubsubSystem.Publish(context.Background(), violationMsg(constraint, ea, gvk, namespace, name, msg, details, labels, timestamp)) if err != nil { am.log.Error(err, "pubsub audit Publishing") } diff --git a/pkg/audit/stats_reporter.go b/pkg/audit/stats_reporter.go index ca759a2e0a2..8fb5c4678d9 100644 --- a/pkg/audit/stats_reporter.go +++ b/pkg/audit/stats_reporter.go @@ -97,7 +97,6 @@ func newStatsReporter() (*reporter, error) { metric.WithDescription("Total number of audited violations"), metric.WithInt64Callback(r.observeTotalViolations), ) - if err != nil { return nil, err } diff --git a/pkg/controller/constrainttemplate/stats_reporter.go b/pkg/controller/constrainttemplate/stats_reporter.go index 68ac2781496..77754344622 100644 --- a/pkg/controller/constrainttemplate/stats_reporter.go +++ b/pkg/controller/constrainttemplate/stats_reporter.go @@ -56,7 +56,6 @@ func newStatsReporter() *reporter { metric.WithDescription(ctDesc), metric.WithInt64Callback(r.observeCTM), ) - if err != nil { panic(err) } diff --git a/pkg/controller/expansion/stats_reporter.go b/pkg/controller/expansion/stats_reporter.go index ebc225c53a8..a919719012a 100644 --- a/pkg/controller/expansion/stats_reporter.go +++ b/pkg/controller/expansion/stats_reporter.go @@ -25,7 +25,6 @@ func newRegistry() *etRegistry { etMetricName, metric.WithDescription(etDesc), metric.WithInt64Callback(r.observeETM)) - if err != nil { panic(err) } diff --git a/pkg/pubsub/connection/connection.go b/pkg/pubsub/connection/connection.go index 0edb6a74daf..789714c79a8 100644 --- a/pkg/pubsub/connection/connection.go +++ b/pkg/pubsub/connection/connection.go @@ -7,7 +7,7 @@ import ( // PubSub is the interface that wraps pubsub methods. type Connection interface { // Publish single message over a specific topic/channel - Publish(ctx context.Context, data interface{}, topic string) error + Publish(ctx context.Context, data interface{}) error // Close connections CloseConnection() error diff --git a/pkg/pubsub/dapr/dapr.go b/pkg/pubsub/dapr/dapr.go index 0db60445494..4435d8d4394 100644 --- a/pkg/pubsub/dapr/dapr.go +++ b/pkg/pubsub/dapr/dapr.go @@ -12,6 +12,9 @@ import ( type ClientConfig struct { // Name of the component to be used for pub sub messaging Component string `json:"component"` + + // Topic where the messages would be published for the connection + Topic string `json:"topic"` } // Dapr represents driver for interacting with pub sub using dapr. @@ -21,19 +24,22 @@ type Dapr struct { // Name of the pubsub component pubSubComponent string + + // Topic where the messages would be published for the connection + topic string } const ( Name = "dapr" ) -func (r *Dapr) Publish(_ context.Context, data interface{}, topic string) error { +func (r *Dapr) Publish(_ context.Context, data interface{}) error { jsonData, err := json.Marshal(data) if err != nil { return fmt.Errorf("error marshaling data: %w", err) } - err = r.client.PublishEvent(context.Background(), r.pubSubComponent, topic, jsonData) + err = r.client.PublishEvent(context.Background(), r.pubSubComponent, r.topic, jsonData) if err != nil { return fmt.Errorf("error publishing message to dapr: %w", err) } @@ -56,6 +62,11 @@ func (r *Dapr) UpdateConnection(_ context.Context, config interface{}) error { return fmt.Errorf("failed to get value of component") } r.pubSubComponent = cfg.Component + cfg.Topic, ok = m["topic"].(string) + if !ok { + return fmt.Errorf("failed to get value of topic") + } + r.topic = cfg.Topic return nil } @@ -70,6 +81,10 @@ func NewConnection(_ context.Context, config interface{}) (connection.Connection if !ok { return nil, fmt.Errorf("failed to get value of component") } + cfg.Topic, ok = m["topic"].(string) + if !ok { + return nil, fmt.Errorf("failed to get value of topic") + } tmp, err := daprClient.NewClient() if err != nil { @@ -79,5 +94,6 @@ func NewConnection(_ context.Context, config interface{}) (connection.Connection return &Dapr{ client: tmp, pubSubComponent: cfg.Component, + topic: cfg.Topic, }, nil } diff --git a/pkg/pubsub/dapr/dapr_test.go b/pkg/pubsub/dapr/dapr_test.go index 5a2e72615b1..4ab05a0d42d 100644 --- a/pkg/pubsub/dapr/dapr_test.go +++ b/pkg/pubsub/dapr/dapr_test.go @@ -55,9 +55,8 @@ func TestDapr_Publish(t *testing.T) { ctx := context.Background() type args struct { - ctx context.Context - data interface{} - topic string + ctx context.Context + data interface{} } tests := []struct { @@ -72,35 +71,22 @@ func TestDapr_Publish(t *testing.T) { data: map[string]interface{}{ "test": "test", }, - topic: "test", }, wantErr: false, }, { name: "test publish without data", args: args{ - ctx: ctx, - data: nil, - topic: "test", + ctx: ctx, + data: nil, }, wantErr: false, }, - { - name: "test publish without topic", - args: args{ - ctx: ctx, - data: map[string]interface{}{ - "test": "test", - }, - topic: "", - }, - wantErr: true, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := testClient - if err := r.Publish(tt.args.ctx, tt.args.data, tt.args.topic); (err != nil) != tt.wantErr { + if err := r.Publish(tt.args.ctx, tt.args.data); (err != nil) != tt.wantErr { t.Errorf("Dapr.Publish() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/pkg/pubsub/dapr/fake_dapr_client.go b/pkg/pubsub/dapr/fake_dapr_client.go index f8304d4a517..ebcbcad2816 100644 --- a/pkg/pubsub/dapr/fake_dapr_client.go +++ b/pkg/pubsub/dapr/fake_dapr_client.go @@ -328,6 +328,7 @@ func FakeConnection() (connection.Connection, func()) { return &Dapr{ client: c, pubSubComponent: "test", + topic: "test", }, f } @@ -338,11 +339,14 @@ type FakeDapr struct { // Name of the pubsub component pubSubComponent string + // Name of the topic + topic string + // closing function f func() } -func (r *FakeDapr) Publish(_ context.Context, _ interface{}, _ string) error { +func (r *FakeDapr) Publish(_ context.Context, _ interface{}) error { return nil } @@ -376,6 +380,10 @@ func FakeNewConnection(ctx context.Context, config interface{}) (connection.Conn if !ok { return nil, fmt.Errorf("failed to get value of component") } + cfg.Topic, ok = m["topic"].(string) + if !ok { + return nil, fmt.Errorf("failed to get value of topic") + } c, f := getTestClient(ctx) @@ -383,5 +391,6 @@ func FakeNewConnection(ctx context.Context, config interface{}) (connection.Conn client: c, pubSubComponent: cfg.Component, f: f, + topic: cfg.Topic, }, nil } diff --git a/pkg/pubsub/system.go b/pkg/pubsub/system.go index da60a1be8e6..76ff56a46bb 100644 --- a/pkg/pubsub/system.go +++ b/pkg/pubsub/system.go @@ -2,6 +2,7 @@ package pubsub import ( "context" + "errors" "fmt" "sync" @@ -19,16 +20,19 @@ func NewSystem() *System { return &System{} } -func (s *System) Publish(_ context.Context, connection string, topic string, msg interface{}) error { +func (s *System) Publish(ctx context.Context, msg interface{}) error { s.mux.RLock() defer s.mux.RUnlock() - if len(s.connections) > 0 { - if c, ok := s.connections[connection]; ok { - return c.Publish(context.Background(), msg, topic) - } - return fmt.Errorf("connection is not initialized, name: %s ", connection) + var errs error + + if len(s.connections) == 0 { + return fmt.Errorf("no connections are established") + } + + for _, c := range s.connections { + errs = errors.Join(errs, c.Publish(ctx, msg)) } - return fmt.Errorf("No connections are established") + return errs } func (s *System) UpsertConnection(ctx context.Context, config interface{}, name string, provider string) error { diff --git a/pkg/pubsub/system_test.go b/pkg/pubsub/system_test.go index a58e43c8cb7..8b8b2108d1a 100644 --- a/pkg/pubsub/system_test.go +++ b/pkg/pubsub/system_test.go @@ -24,6 +24,7 @@ func TestMain(m *testing.M) { cfg := map[string]interface{}{ dapr.Name: map[string]interface{}{ "component": "pubsub", + "topic": "audit", }, } for name, fakeConn := range tmp { @@ -90,6 +91,7 @@ func TestSystem_UpsertConnection(t *testing.T) { ctx: context.Background(), config: map[string]interface{}{ "component": "pubsub", + "topic": "test", }, name: "dapr", provider: "dapr", @@ -111,6 +113,7 @@ func TestSystem_UpsertConnection(t *testing.T) { ctx: context.Background(), config: map[string]interface{}{ "component": "pubsub", + "topic": "test", }, name: "audit", provider: "test", @@ -133,6 +136,7 @@ func TestSystem_UpsertConnection(t *testing.T) { ctx: context.Background(), config: map[string]interface{}{ "component": "test", + "topic": "audit", }, name: "audit", provider: "dapr", @@ -222,15 +226,6 @@ func TestSystem_Publish(t *testing.T) { args: args{ctx: context.Background(), connection: "audit", topic: "test", msg: nil}, wantErr: true, }, - { - name: "Publishing to a connection that does not exist", - fields: fields{ - connections: map[string]connection.Connection{"audit": &dapr.Dapr{}}, - providers: map[string]string{"audit": "dapr"}, - }, - args: args{ctx: context.Background(), connection: "test", topic: "test", msg: nil}, - wantErr: true, - }, { name: "Publishing to a connection that does exist", fields: fields{ @@ -248,7 +243,7 @@ func TestSystem_Publish(t *testing.T) { connections: tt.fields.connections, providers: tt.fields.providers, } - if err := s.Publish(tt.args.ctx, tt.args.connection, tt.args.topic, tt.args.msg); (err != nil) != tt.wantErr { + if err := s.Publish(tt.args.ctx, tt.args.msg); (err != nil) != tt.wantErr { t.Errorf("System.Publish() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/test/pubsub/publish-components.yaml b/test/pubsub/publish-components.yaml index 9686935dd01..af0981bb0cc 100644 --- a/test/pubsub/publish-components.yaml +++ b/test/pubsub/publish-components.yaml @@ -24,5 +24,6 @@ data: provider: "dapr" config: | { - "component": "pubsub" + "component": "pubsub", + "topic": "audit-channel" } diff --git a/test/testutils/controller.go b/test/testutils/controller.go index 7dd11303e3a..146ea9a9a02 100644 --- a/test/testutils/controller.go +++ b/test/testutils/controller.go @@ -116,7 +116,6 @@ func DeleteObjectAndConfirm(ctx context.Context, t *testing.T, c client.Client, s, _ := json.MarshalIndent(toGet, "", " ") return fmt.Errorf("found %v %v:\n%s", gvk, key, string(s)) }) - if err != nil { t.Fatal(err) } diff --git a/website/docs/pubsub.md b/website/docs/pubsub.md index 8c1df5fb3c0..98c53e64714 100644 --- a/website/docs/pubsub.md +++ b/website/docs/pubsub.md @@ -19,7 +19,7 @@ Install prerequisites such as a pubsub tool, a message broker etc. ### Setting up audit with pubsub enabled -In the audit deployment, set the `--enable-pub-sub` flag to `true` to publish audit violations. Additionally, use `--audit-connection` (defaults to `audit-connection`) and `--audit-channel`(defaults to `audit-channel`) flags to allow audit to publish violations using desired connection onto desired channel. `--audit-connection` must be set to the name of the connection config, and `--audit-channel` must be set to name of the channel where violations should get published. +In the audit deployment, set the `--enable-pub-sub` flag to `true` to publish audit violations. A ConfigMap that contains `provider` and `config` fields in `data` is required to establish connection for sending violations over the channel. Following is an example ConfigMap to establish a connection that uses Dapr to publish messages: @@ -33,7 +33,8 @@ data: provider: "dapr" config: | { - "component": "pubsub" + "component": "pubsub", + "topic": "audit-channel" } ``` @@ -125,6 +126,9 @@ Dapr: https://dapr.io/ - name: go-sub image: fake-subscriber:latest imagePullPolicy: Never + env: + - name: AUDIT_CHANNEL + value: "audit-channel" ``` > [!IMPORTANT] @@ -156,15 +160,13 @@ Dapr: https://dapr.io/ EOF ``` -2. To upgrade or install Gatekeeper with `--enable-pub-sub` set to `true`, `--audit-connection` set to `audit-connection`, `--audit-channel` set to `audit-channel` on audit pod. +2. To upgrade or install Gatekeeper with `--enable-pub-sub` set to `true` on audit pod. ```shell # auditPodAnnotations is used to add annotations required by Dapr to inject sidecar to audit pod echo 'auditPodAnnotations: {dapr.io/enabled: "true", dapr.io/app-id: "audit", dapr.io/metrics-port: "9999", dapr.io/sidecar-seccomp-profile-type: "RuntimeDefault"}' > /tmp/annotations.yaml helm upgrade --install gatekeeper gatekeeper/gatekeeper --namespace gatekeeper-system \ --set audit.enablePubsub=true \ - --set audit.connection=audit-connection \ - --set audit.channel=audit-channel \ --values /tmp/annotations.yaml ``` @@ -183,13 +185,12 @@ Dapr: https://dapr.io/ provider: "dapr" config: | { - "component": "pubsub" + "component": "pubsub", + "topic": "audit-channel" } EOF ``` - **Note:** Name of the connection configMap must match the value of `--audit-connection` for it to be used by audit to publish violation. At the moment, only one connection config can exists for audit. - 4. Create the constraint templates and constraints, and make sure audit ran by checking constraints. If constraint status is updated with information such as `auditTimeStamp` or `totalViolations`, then audit has ran at least once. Additionally, populated `TOTAL-VIOLATIONS` field for all constraints while listing constraints also indicates that audit has ran at least once. ```log From d7733ddbf7a3ad28558ed7777a1a3da12b52fa3a Mon Sep 17 00:00:00 2001 From: Jaydip Gabani Date: Thu, 28 Mar 2024 12:07:37 +0000 Subject: [PATCH 2/4] fixing tests Signed-off-by: Jaydip Gabani --- pkg/pubsub/dapr/dapr_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/pubsub/dapr/dapr_test.go b/pkg/pubsub/dapr/dapr_test.go index 4ab05a0d42d..f857560b50a 100644 --- a/pkg/pubsub/dapr/dapr_test.go +++ b/pkg/pubsub/dapr/dapr_test.go @@ -103,6 +103,7 @@ func TestDapr_UpdateConnection(t *testing.T) { name: "test update connection", config: map[string]interface{}{ "component": "foo", + "topic": "bar", }, wantErr: false, }, From 2d279285040a688e023cbaafa0c5620be0e13017 Mon Sep 17 00:00:00 2001 From: Jaydip Gabani Date: Tue, 9 Apr 2024 18:41:11 +0000 Subject: [PATCH 3/4] correcting enablepubsub variable Signed-off-by: Jaydip Gabani --- cmd/build/helmify/static/values.yaml | 2 +- manifest_staging/charts/gatekeeper/values.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/build/helmify/static/values.yaml b/cmd/build/helmify/static/values.yaml index ff8bce24df7..0ce91c4816f 100644 --- a/cmd/build/helmify/static/values.yaml +++ b/cmd/build/helmify/static/values.yaml @@ -214,7 +214,7 @@ controllerManager: # - ipBlock: # cidr: 0.0.0.0/0 audit: - channel: audit-channel + enablePubSub: false hostNetwork: false dnsPolicy: ClusterFirst metricsPort: 8888 diff --git a/manifest_staging/charts/gatekeeper/values.yaml b/manifest_staging/charts/gatekeeper/values.yaml index ff8bce24df7..0ce91c4816f 100644 --- a/manifest_staging/charts/gatekeeper/values.yaml +++ b/manifest_staging/charts/gatekeeper/values.yaml @@ -214,7 +214,7 @@ controllerManager: # - ipBlock: # cidr: 0.0.0.0/0 audit: - channel: audit-channel + enablePubSub: false hostNetwork: false dnsPolicy: ClusterFirst metricsPort: 8888 From aecd00db0a6324a7a462c3afa296d53769f81048 Mon Sep 17 00:00:00 2001 From: Jaydip Gabani Date: Tue, 9 Apr 2024 21:13:52 +0000 Subject: [PATCH 4/4] fixing vulns Signed-off-by: Jaydip Gabani --- go.mod | 2 +- go.sum | 4 +- vendor/golang.org/x/net/http2/frame.go | 31 ++ vendor/golang.org/x/net/http2/pipe.go | 11 +- vendor/golang.org/x/net/http2/server.go | 13 +- vendor/golang.org/x/net/http2/testsync.go | 331 +++++++++++++++++++++ vendor/golang.org/x/net/http2/transport.go | 298 +++++++++++++++---- vendor/modules.txt | 2 +- 8 files changed, 620 insertions(+), 72 deletions(-) create mode 100644 vendor/golang.org/x/net/http2/testsync.go diff --git a/go.mod b/go.mod index 3df52a4954f..f3d1eb35343 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( go.opentelemetry.io/otel/sdk/metric v1.19.0 go.uber.org/automaxprocs v1.5.3 go.uber.org/zap v1.26.0 - golang.org/x/net v0.22.0 + golang.org/x/net v0.23.0 golang.org/x/oauth2 v0.17.0 golang.org/x/sync v0.6.0 golang.org/x/time v0.5.0 diff --git a/go.sum b/go.sum index ca6693945bb..4244788dc45 100644 --- a/go.sum +++ b/go.sum @@ -447,8 +447,8 @@ golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ= golang.org/x/oauth2 v0.17.0/go.mod h1:OzPDGQiuQMguemayvdylqddI7qcD9lnSDb+1FiwQ5HA= diff --git a/vendor/golang.org/x/net/http2/frame.go b/vendor/golang.org/x/net/http2/frame.go index e2b298d8593..43557ab7e97 100644 --- a/vendor/golang.org/x/net/http2/frame.go +++ b/vendor/golang.org/x/net/http2/frame.go @@ -1564,6 +1564,7 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { if size > remainSize { hdec.SetEmitEnabled(false) mh.Truncated = true + remainSize = 0 return } remainSize -= size @@ -1576,6 +1577,36 @@ func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) { var hc headersOrContinuation = hf for { frag := hc.HeaderBlockFragment() + + // Avoid parsing large amounts of headers that we will then discard. + // If the sender exceeds the max header list size by too much, + // skip parsing the fragment and close the connection. + // + // "Too much" is either any CONTINUATION frame after we've already + // exceeded the max header list size (in which case remainSize is 0), + // or a frame whose encoded size is more than twice the remaining + // header list bytes we're willing to accept. + if int64(len(frag)) > int64(2*remainSize) { + if VerboseLogs { + log.Printf("http2: header list too large") + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + + // Also close the connection after any CONTINUATION frame following an + // invalid header, since we stop tracking the size of the headers after + // an invalid one. + if invalid != nil { + if VerboseLogs { + log.Printf("http2: invalid header: %v", invalid) + } + // It would be nice to send a RST_STREAM before sending the GOAWAY, + // but the structure of the server's frame writer makes this difficult. + return nil, ConnectionError(ErrCodeProtocol) + } + if _, err := hdec.Write(frag); err != nil { return nil, ConnectionError(ErrCodeCompression) } diff --git a/vendor/golang.org/x/net/http2/pipe.go b/vendor/golang.org/x/net/http2/pipe.go index 684d984fd96..3b9f06b9624 100644 --- a/vendor/golang.org/x/net/http2/pipe.go +++ b/vendor/golang.org/x/net/http2/pipe.go @@ -77,7 +77,10 @@ func (p *pipe) Read(d []byte) (n int, err error) { } } -var errClosedPipeWrite = errors.New("write on closed buffer") +var ( + errClosedPipeWrite = errors.New("write on closed buffer") + errUninitializedPipeWrite = errors.New("write on uninitialized buffer") +) // Write copies bytes from p into the buffer and wakes a reader. // It is an error to write more data than the buffer can hold. @@ -91,6 +94,12 @@ func (p *pipe) Write(d []byte) (n int, err error) { if p.err != nil || p.breakErr != nil { return 0, errClosedPipeWrite } + // pipe.setBuffer is never invoked, leaving the buffer uninitialized. + // We shouldn't try to write to an uninitialized pipe, + // but returning an error is better than panicking. + if p.b == nil { + return 0, errUninitializedPipeWrite + } return p.b.Write(d) } diff --git a/vendor/golang.org/x/net/http2/server.go b/vendor/golang.org/x/net/http2/server.go index ae94c6408d5..ce2e8b40eee 100644 --- a/vendor/golang.org/x/net/http2/server.go +++ b/vendor/golang.org/x/net/http2/server.go @@ -124,6 +124,7 @@ type Server struct { // IdleTimeout specifies how long until idle clients should be // closed with a GOAWAY frame. PING frames are not considered // activity for the purposes of IdleTimeout. + // If zero or negative, there is no timeout. IdleTimeout time.Duration // MaxUploadBufferPerConnection is the size of the initial flow @@ -434,7 +435,7 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) { // passes the connection off to us with the deadline already set. // Write deadlines are set per stream in serverConn.newStream. // Disarm the net.Conn write deadline here. - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { sc.conn.SetWriteDeadline(time.Time{}) } @@ -924,7 +925,7 @@ func (sc *serverConn) serve() { sc.setConnState(http.StateActive) sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer = time.AfterFunc(sc.srv.IdleTimeout, sc.onIdleTimer) defer sc.idleTimer.Stop() } @@ -1637,7 +1638,7 @@ func (sc *serverConn) closeStream(st *stream, err error) { delete(sc.streams, st.id) if len(sc.streams) == 0 { sc.setConnState(http.StateIdle) - if sc.srv.IdleTimeout != 0 { + if sc.srv.IdleTimeout > 0 { sc.idleTimer.Reset(sc.srv.IdleTimeout) } if h1ServerKeepAlivesDisabled(sc.hs) { @@ -2017,7 +2018,7 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { // similar to how the http1 server works. Here it's // technically more like the http1 Server's ReadHeaderTimeout // (in Go 1.8), though. That's a more sane option anyway. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout) } @@ -2038,7 +2039,7 @@ func (sc *serverConn) upgradeRequest(req *http.Request) { // Disable any read deadline set by the net/http package // prior to the upgrade. - if sc.hs.ReadTimeout != 0 { + if sc.hs.ReadTimeout > 0 { sc.conn.SetReadDeadline(time.Time{}) } @@ -2116,7 +2117,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream st.flow.conn = &sc.flow // link to conn-level counter st.flow.add(sc.initialStreamSendWindowSize) st.inflow.init(sc.srv.initialStreamRecvWindowSize()) - if sc.hs.WriteTimeout != 0 { + if sc.hs.WriteTimeout > 0 { st.writeDeadline = time.AfterFunc(sc.hs.WriteTimeout, st.onWriteTimeout) } diff --git a/vendor/golang.org/x/net/http2/testsync.go b/vendor/golang.org/x/net/http2/testsync.go new file mode 100644 index 00000000000..61075bd16d3 --- /dev/null +++ b/vendor/golang.org/x/net/http2/testsync.go @@ -0,0 +1,331 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package http2 + +import ( + "context" + "sync" + "time" +) + +// testSyncHooks coordinates goroutines in tests. +// +// For example, a call to ClientConn.RoundTrip involves several goroutines, including: +// - the goroutine running RoundTrip; +// - the clientStream.doRequest goroutine, which writes the request; and +// - the clientStream.readLoop goroutine, which reads the response. +// +// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines +// are blocked waiting for some condition such as reading the Request.Body or waiting for +// flow control to become available. +// +// The testSyncHooks also manage timers and synthetic time in tests. +// This permits us to, for example, start a request and cause it to time out waiting for +// response headers without resorting to time.Sleep calls. +type testSyncHooks struct { + // active/inactive act as a mutex and condition variable. + // + // - neither chan contains a value: testSyncHooks is locked. + // - active contains a value: unlocked, and at least one goroutine is not blocked + // - inactive contains a value: unlocked, and all goroutines are blocked + active chan struct{} + inactive chan struct{} + + // goroutine counts + total int // total goroutines + condwait map[*sync.Cond]int // blocked in sync.Cond.Wait + blocked []*testBlockedGoroutine // otherwise blocked + + // fake time + now time.Time + timers []*fakeTimer + + // Transport testing: Report various events. + newclientconn func(*ClientConn) + newstream func(*clientStream) +} + +// testBlockedGoroutine is a blocked goroutine. +type testBlockedGoroutine struct { + f func() bool // blocked until f returns true + ch chan struct{} // closed when unblocked +} + +func newTestSyncHooks() *testSyncHooks { + h := &testSyncHooks{ + active: make(chan struct{}, 1), + inactive: make(chan struct{}, 1), + condwait: map[*sync.Cond]int{}, + } + h.inactive <- struct{}{} + h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + return h +} + +// lock acquires the testSyncHooks mutex. +func (h *testSyncHooks) lock() { + select { + case <-h.active: + case <-h.inactive: + } +} + +// waitInactive waits for all goroutines to become inactive. +func (h *testSyncHooks) waitInactive() { + for { + <-h.inactive + if !h.unlock() { + break + } + } +} + +// unlock releases the testSyncHooks mutex. +// It reports whether any goroutines are active. +func (h *testSyncHooks) unlock() (active bool) { + // Look for a blocked goroutine which can be unblocked. + blocked := h.blocked[:0] + unblocked := false + for _, b := range h.blocked { + if !unblocked && b.f() { + unblocked = true + close(b.ch) + } else { + blocked = append(blocked, b) + } + } + h.blocked = blocked + + // Count goroutines blocked on condition variables. + condwait := 0 + for _, count := range h.condwait { + condwait += count + } + + if h.total > condwait+len(blocked) { + h.active <- struct{}{} + return true + } else { + h.inactive <- struct{}{} + return false + } +} + +// goRun starts a new goroutine. +func (h *testSyncHooks) goRun(f func()) { + h.lock() + h.total++ + h.unlock() + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + f() + }() +} + +// blockUntil indicates that a goroutine is blocked waiting for some condition to become true. +// It waits until f returns true before proceeding. +// +// Example usage: +// +// h.blockUntil(func() bool { +// // Is the context done yet? +// select { +// case <-ctx.Done(): +// default: +// return false +// } +// return true +// }) +// // Wait for the context to become done. +// <-ctx.Done() +// +// The function f passed to blockUntil must be non-blocking and idempotent. +func (h *testSyncHooks) blockUntil(f func() bool) { + if f() { + return + } + ch := make(chan struct{}) + h.lock() + h.blocked = append(h.blocked, &testBlockedGoroutine{ + f: f, + ch: ch, + }) + h.unlock() + <-ch +} + +// broadcast is sync.Cond.Broadcast. +func (h *testSyncHooks) condBroadcast(cond *sync.Cond) { + h.lock() + delete(h.condwait, cond) + h.unlock() + cond.Broadcast() +} + +// broadcast is sync.Cond.Wait. +func (h *testSyncHooks) condWait(cond *sync.Cond) { + h.lock() + h.condwait[cond]++ + h.unlock() +} + +// newTimer creates a new fake timer. +func (h *testSyncHooks) newTimer(d time.Duration) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + c: make(chan time.Time), + } + h.timers = append(h.timers, t) + return t +} + +// afterFunc creates a new fake AfterFunc timer. +func (h *testSyncHooks) afterFunc(d time.Duration, f func()) timer { + h.lock() + defer h.unlock() + t := &fakeTimer{ + hooks: h, + when: h.now.Add(d), + f: f, + } + h.timers = append(h.timers, t) + return t +} + +func (h *testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + t := h.afterFunc(d, cancel) + return ctx, func() { + t.Stop() + cancel() + } +} + +func (h *testSyncHooks) timeUntilEvent() time.Duration { + h.lock() + defer h.unlock() + var next time.Time + for _, t := range h.timers { + if next.IsZero() || t.when.Before(next) { + next = t.when + } + } + if d := next.Sub(h.now); d > 0 { + return d + } + return 0 +} + +// advance advances time and causes synthetic timers to fire. +func (h *testSyncHooks) advance(d time.Duration) { + h.lock() + defer h.unlock() + h.now = h.now.Add(d) + timers := h.timers[:0] + for _, t := range h.timers { + t := t // remove after go.mod depends on go1.22 + t.mu.Lock() + switch { + case t.when.After(h.now): + timers = append(timers, t) + case t.when.IsZero(): + // stopped timer + default: + t.when = time.Time{} + if t.c != nil { + close(t.c) + } + if t.f != nil { + h.total++ + go func() { + defer func() { + h.lock() + h.total-- + h.unlock() + }() + t.f() + }() + } + } + t.mu.Unlock() + } + h.timers = timers +} + +// A timer wraps a time.Timer, or a synthetic equivalent in tests. +// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires. +type timer interface { + C() <-chan time.Time + Stop() bool + Reset(d time.Duration) bool +} + +// timeTimer implements timer using real time. +type timeTimer struct { + t *time.Timer + c chan time.Time +} + +// newTimeTimer creates a new timer using real time. +func newTimeTimer(d time.Duration) timer { + ch := make(chan time.Time) + t := time.AfterFunc(d, func() { + close(ch) + }) + return &timeTimer{t, ch} +} + +// newTimeAfterFunc creates an AfterFunc timer using real time. +func newTimeAfterFunc(d time.Duration, f func()) timer { + return &timeTimer{ + t: time.AfterFunc(d, f), + } +} + +func (t timeTimer) C() <-chan time.Time { return t.c } +func (t timeTimer) Stop() bool { return t.t.Stop() } +func (t timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) } + +// fakeTimer implements timer using fake time. +type fakeTimer struct { + hooks *testSyncHooks + + mu sync.Mutex + when time.Time // when the timer will fire + c chan time.Time // closed when the timer fires; mutually exclusive with f + f func() // called when the timer fires; mutually exclusive with c +} + +func (t *fakeTimer) C() <-chan time.Time { return t.c } + +func (t *fakeTimer) Stop() bool { + t.mu.Lock() + defer t.mu.Unlock() + stopped := t.when.IsZero() + t.when = time.Time{} + return stopped +} + +func (t *fakeTimer) Reset(d time.Duration) bool { + if t.c != nil || t.f == nil { + panic("fakeTimer only supports Reset on AfterFunc timers") + } + t.mu.Lock() + defer t.mu.Unlock() + t.hooks.lock() + defer t.hooks.unlock() + active := !t.when.IsZero() + t.when = t.hooks.now.Add(d) + if !active { + t.hooks.timers = append(t.hooks.timers, t) + } + return active +} diff --git a/vendor/golang.org/x/net/http2/transport.go b/vendor/golang.org/x/net/http2/transport.go index c2a5b44b3d6..ce375c8c753 100644 --- a/vendor/golang.org/x/net/http2/transport.go +++ b/vendor/golang.org/x/net/http2/transport.go @@ -147,6 +147,12 @@ type Transport struct { // waiting for their turn. StrictMaxConcurrentStreams bool + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ReadIdleTimeout is the timeout after which a health check using ping // frame will be carried out if no frame is received on the connection. // Note that a ping response will is considered a received frame, so if @@ -178,6 +184,8 @@ type Transport struct { connPoolOnce sync.Once connPoolOrDef ClientConnPool // non-nil version of ConnPool + + syncHooks *testSyncHooks } func (t *Transport) maxHeaderListSize() uint32 { @@ -302,7 +310,7 @@ type ClientConn struct { readerErr error // set before readerDone is closed idleTimeout time.Duration // or 0 for never - idleTimer *time.Timer + idleTimer timer mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes @@ -344,6 +352,60 @@ type ClientConn struct { werr error // first write error that has occurred hbuf bytes.Buffer // HPACK encoder writes into this henc *hpack.Encoder + + syncHooks *testSyncHooks // can be nil +} + +// Hook points used for testing. +// Outside of tests, cc.syncHooks is nil and these all have minimal implementations. +// Inside tests, see the testSyncHooks function docs. + +// goRun starts a new goroutine. +func (cc *ClientConn) goRun(f func()) { + if cc.syncHooks != nil { + cc.syncHooks.goRun(f) + return + } + go f() +} + +// condBroadcast is cc.cond.Broadcast. +func (cc *ClientConn) condBroadcast() { + if cc.syncHooks != nil { + cc.syncHooks.condBroadcast(cc.cond) + } + cc.cond.Broadcast() +} + +// condWait is cc.cond.Wait. +func (cc *ClientConn) condWait() { + if cc.syncHooks != nil { + cc.syncHooks.condWait(cc.cond) + } + cc.cond.Wait() +} + +// newTimer creates a new time.Timer, or a synthetic timer in tests. +func (cc *ClientConn) newTimer(d time.Duration) timer { + if cc.syncHooks != nil { + return cc.syncHooks.newTimer(d) + } + return newTimeTimer(d) +} + +// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests. +func (cc *ClientConn) afterFunc(d time.Duration, f func()) timer { + if cc.syncHooks != nil { + return cc.syncHooks.afterFunc(d, f) + } + return newTimeAfterFunc(d, f) +} + +func (cc *ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) { + if cc.syncHooks != nil { + return cc.syncHooks.contextWithTimeout(ctx, d) + } + return context.WithTimeout(ctx, d) } // clientStream is the state for a single HTTP/2 stream. One of these @@ -425,7 +487,7 @@ func (cs *clientStream) abortStreamLocked(err error) { // TODO(dneil): Clean up tests where cs.cc.cond is nil. if cs.cc.cond != nil { // Wake up writeRequestBody if it is waiting on flow control. - cs.cc.cond.Broadcast() + cs.cc.condBroadcast() } } @@ -435,7 +497,7 @@ func (cs *clientStream) abortRequestBodyWrite() { defer cc.mu.Unlock() if cs.reqBody != nil && cs.reqBodyClosed == nil { cs.closeReqBodyLocked() - cc.cond.Broadcast() + cc.condBroadcast() } } @@ -445,10 +507,10 @@ func (cs *clientStream) closeReqBodyLocked() { } cs.reqBodyClosed = make(chan struct{}) reqBodyClosed := cs.reqBodyClosed - go func() { + cs.cc.goRun(func() { cs.reqBody.Close() close(reqBodyClosed) - }() + }) } type stickyErrWriter struct { @@ -537,15 +599,6 @@ func authorityAddr(scheme string, authority string) (addr string) { return net.JoinHostPort(host, port) } -var retryBackoffHook func(time.Duration) *time.Timer - -func backoffNewTimer(d time.Duration) *time.Timer { - if retryBackoffHook != nil { - return retryBackoffHook(d) - } - return time.NewTimer(d) -} - // RoundTripOpt is like RoundTrip, but takes options. func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { @@ -573,13 +626,27 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res backoff := float64(uint(1) << (uint(retry) - 1)) backoff += backoff * (0.1 * mathrand.Float64()) d := time.Second * time.Duration(backoff) - timer := backoffNewTimer(d) + var tm timer + if t.syncHooks != nil { + tm = t.syncHooks.newTimer(d) + t.syncHooks.blockUntil(func() bool { + select { + case <-tm.C(): + case <-req.Context().Done(): + default: + return false + } + return true + }) + } else { + tm = newTimeTimer(d) + } select { - case <-timer.C: + case <-tm.C(): t.vlogf("RoundTrip retrying after failure: %v", roundTripErr) continue case <-req.Context().Done(): - timer.Stop() + tm.Stop() err = req.Context().Err() } } @@ -658,6 +725,9 @@ func canRetryError(err error) bool { } func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*ClientConn, error) { + if t.syncHooks != nil { + return t.newClientConn(nil, singleUse, t.syncHooks) + } host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -666,7 +736,7 @@ func (t *Transport) dialClientConn(ctx context.Context, addr string, singleUse b if err != nil { return nil, err } - return t.newClientConn(tconn, singleUse) + return t.newClientConn(tconn, singleUse, nil) } func (t *Transport) newTLSConfig(host string) *tls.Config { @@ -732,10 +802,10 @@ func (t *Transport) maxEncoderHeaderTableSize() uint32 { } func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { - return t.newClientConn(c, t.disableKeepAlives()) + return t.newClientConn(c, t.disableKeepAlives(), nil) } -func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) { +func (t *Transport) newClientConn(c net.Conn, singleUse bool, hooks *testSyncHooks) (*ClientConn, error) { cc := &ClientConn{ t: t, tconn: c, @@ -750,10 +820,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro wantSettingsAck: true, pings: make(map[[8]byte]chan struct{}), reqHeaderMu: make(chan struct{}, 1), + syncHooks: hooks, + } + if hooks != nil { + hooks.newclientconn(cc) + c = cc.tconn } if d := t.idleConnTimeout(); d != 0 { cc.idleTimeout = d - cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) + cc.idleTimer = cc.afterFunc(d, cc.onIdleTimeout) } if VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -818,7 +893,7 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro return nil, cc.werr } - go cc.readLoop() + cc.goRun(cc.readLoop) return cc, nil } @@ -826,7 +901,7 @@ func (cc *ClientConn) healthCheck() { pingTimeout := cc.t.pingTimeout() // We don't need to periodically ping in the health check, because the readLoop of ClientConn will // trigger the healthCheck again if there is no frame received. - ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout) defer cancel() cc.vlogf("http2: Transport sending health check") err := cc.Ping(ctx) @@ -1056,7 +1131,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { // Wait for all in-flight streams to complete or connection to close done := make(chan struct{}) cancelled := false // guarded by cc.mu - go func() { + cc.goRun(func() { cc.mu.Lock() defer cc.mu.Unlock() for { @@ -1068,9 +1143,9 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { if cancelled { break } - cc.cond.Wait() + cc.condWait() } - }() + }) shutdownEnterWaitStateHook() select { case <-done: @@ -1080,7 +1155,7 @@ func (cc *ClientConn) Shutdown(ctx context.Context) error { cc.mu.Lock() // Free the goroutine above cancelled = true - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() return ctx.Err() } @@ -1118,7 +1193,7 @@ func (cc *ClientConn) closeForError(err error) { for _, cs := range cc.streams { cs.abortStreamLocked(err) } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() cc.closeConn() } @@ -1215,6 +1290,10 @@ func (cc *ClientConn) decrStreamReservationsLocked() { } func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { + return cc.roundTrip(req, nil) +} + +func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) (*http.Response, error) { ctx := req.Context() cs := &clientStream{ cc: cc, @@ -1229,9 +1308,23 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { respHeaderRecv: make(chan struct{}), donec: make(chan struct{}), } - go cs.doRequest(req) + cc.goRun(func() { + cs.doRequest(req) + }) waitDone := func() error { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.donec: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.donec: return nil @@ -1292,7 +1385,24 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return err } + if streamf != nil { + streamf(cs) + } + for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.respHeaderRecv: return handleResponseHeaders() @@ -1348,6 +1458,21 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { if cc.reqHeaderMu == nil { panic("RoundTrip on uninitialized ClientConn") // for tests } + var newStreamHook func(*clientStream) + if cc.syncHooks != nil { + newStreamHook = cc.syncHooks.newstream + cc.syncHooks.blockUntil(func() bool { + select { + case cc.reqHeaderMu <- struct{}{}: + <-cc.reqHeaderMu + case <-cs.reqCancel: + case <-ctx.Done(): + default: + return false + } + return true + }) + } select { case cc.reqHeaderMu <- struct{}{}: case <-cs.reqCancel: @@ -1372,6 +1497,10 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { } cc.mu.Unlock() + if newStreamHook != nil { + newStreamHook(cs) + } + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && @@ -1452,15 +1581,30 @@ func (cs *clientStream) writeRequest(req *http.Request) (err error) { var respHeaderTimer <-chan time.Time var respHeaderRecv chan struct{} if d := cc.responseHeaderTimeout(); d != 0 { - timer := time.NewTimer(d) + timer := cc.newTimer(d) defer timer.Stop() - respHeaderTimer = timer.C + respHeaderTimer = timer.C() respHeaderRecv = cs.respHeaderRecv } // Wait until the peer half-closes its end of the stream, // or until the request is aborted (via context, error, or otherwise), // whichever comes first. for { + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-cs.peerClosed: + case <-respHeaderTimer: + case <-respHeaderRecv: + case <-cs.abort: + case <-ctx.Done(): + case <-cs.reqCancel: + default: + return false + } + return true + }) + } select { case <-cs.peerClosed: return nil @@ -1609,7 +1753,7 @@ func (cc *ClientConn) awaitOpenSlotForStreamLocked(cs *clientStream) error { return nil } cc.pendingRequests++ - cc.cond.Wait() + cc.condWait() cc.pendingRequests-- select { case <-cs.abort: @@ -1871,8 +2015,24 @@ func (cs *clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) cs.flow.take(take) return take, nil } - cc.cond.Wait() + cc.condWait() + } +} + +func validateHeaders(hdrs http.Header) string { + for k, vv := range hdrs { + if !httpguts.ValidHeaderFieldName(k) { + return fmt.Sprintf("name %q", k) + } + for _, v := range vv { + if !httpguts.ValidHeaderFieldValue(v) { + // Don't include the value in the error, + // because it may be sensitive. + return fmt.Sprintf("value for header %q", k) + } + } } + return "" } var errNilRequestURL = errors.New("http2: Request.URI is nil") @@ -1912,19 +2072,14 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail } } - // Check for any invalid headers and return an error before we + // Check for any invalid headers+trailers and return an error before we // potentially pollute our hpack state. (We want to be able to // continue to reuse the hpack encoder for future requests) - for k, vv := range req.Header { - if !httpguts.ValidHeaderFieldName(k) { - return nil, fmt.Errorf("invalid HTTP header name %q", k) - } - for _, v := range vv { - if !httpguts.ValidHeaderFieldValue(v) { - // Don't include the value in the error, because it may be sensitive. - return nil, fmt.Errorf("invalid HTTP header value for header %q", k) - } - } + if err := validateHeaders(req.Header); err != "" { + return nil, fmt.Errorf("invalid HTTP header %s", err) + } + if err := validateHeaders(req.Trailer); err != "" { + return nil, fmt.Errorf("invalid HTTP trailer %s", err) } enumerateHeaders := func(f func(name, value string)) { @@ -2143,7 +2298,7 @@ func (cc *ClientConn) forgetStreamID(id uint32) { } // Wake up writeRequestBody via clientStream.awaitFlowControl and // wake up RoundTrip if there is a pending request. - cc.cond.Broadcast() + cc.condBroadcast() closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 { @@ -2231,7 +2386,7 @@ func (rl *clientConnReadLoop) cleanup() { cs.abortStreamLocked(err) } } - cc.cond.Broadcast() + cc.condBroadcast() cc.mu.Unlock() } @@ -2266,10 +2421,9 @@ func (rl *clientConnReadLoop) run() error { cc := rl.cc gotSettings := false readIdleTimeout := cc.t.ReadIdleTimeout - var t *time.Timer + var t timer if readIdleTimeout != 0 { - t = time.AfterFunc(readIdleTimeout, cc.healthCheck) - defer t.Stop() + t = cc.afterFunc(readIdleTimeout, cc.healthCheck) } for { f, err := cc.fr.ReadFrame() @@ -2684,7 +2838,7 @@ func (rl *clientConnReadLoop) processData(f *DataFrame) error { }) return nil } - if !cs.firstByte { + if !cs.pastHeaders { cc.logf("protocol error: received DATA before a HEADERS frame") rl.endStreamError(cs, StreamError{ StreamID: f.StreamID, @@ -2867,7 +3021,7 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error { for _, cs := range cc.streams { cs.flow.add(delta) } - cc.cond.Broadcast() + cc.condBroadcast() cc.initialWindowSize = s.Val case SettingHeaderTableSize: @@ -2922,7 +3076,7 @@ func (rl *clientConnReadLoop) processWindowUpdate(f *WindowUpdateFrame) error { return ConnectionError(ErrCodeFlowControl) } - cc.cond.Broadcast() + cc.condBroadcast() return nil } @@ -2964,24 +3118,38 @@ func (cc *ClientConn) Ping(ctx context.Context) error { } cc.mu.Unlock() } - errc := make(chan error, 1) - go func() { + var pingError error + errc := make(chan struct{}) + cc.goRun(func() { cc.wmu.Lock() defer cc.wmu.Unlock() - if err := cc.fr.WritePing(false, p); err != nil { - errc <- err + if pingError = cc.fr.WritePing(false, p); pingError != nil { + close(errc) return } - if err := cc.bw.Flush(); err != nil { - errc <- err + if pingError = cc.bw.Flush(); pingError != nil { + close(errc) return } - }() + }) + if cc.syncHooks != nil { + cc.syncHooks.blockUntil(func() bool { + select { + case <-c: + case <-errc: + case <-ctx.Done(): + case <-cc.readerDone: + default: + return false + } + return true + }) + } select { case <-c: return nil - case err := <-errc: - return err + case <-errc: + return pingError case <-ctx.Done(): return ctx.Err() case <-cc.readerDone: @@ -3150,9 +3318,17 @@ func (rt noDialH2RoundTripper) RoundTrip(req *http.Request) (*http.Response, err } func (t *Transport) idleConnTimeout() time.Duration { + // to keep things backwards compatible, we use non-zero values of + // IdleConnTimeout, followed by using the IdleConnTimeout on the underlying + // http1 transport, followed by 0 + if t.IdleConnTimeout != 0 { + return t.IdleConnTimeout + } + if t.t1 != nil { return t.t1.IdleConnTimeout } + return 0 } diff --git a/vendor/modules.txt b/vendor/modules.txt index 5b9bdbc748e..0e09335b142 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -730,7 +730,7 @@ golang.org/x/crypto/internal/poly1305 golang.org/x/exp/constraints golang.org/x/exp/maps golang.org/x/exp/slices -# golang.org/x/net v0.22.0 +# golang.org/x/net v0.23.0 ## explicit; go 1.18 golang.org/x/net/context golang.org/x/net/html