From be3176f5c6f4306e379d0c400db6279dc1b97315 Mon Sep 17 00:00:00 2001 From: Eric Wollesen Date: Mon, 9 Jun 2025 12:50:58 -0600 Subject: [PATCH] use a default timeout on requests that don't specify one This behavior was inadvertently lost in e282fb3cfb0004c9a52e014003a00c42b9fd462e, but is restored here, while still allowing for per-request timeouts to be specified. --- client/client.go | 25 +++++++++++++++++------ client/client_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/client/client.go b/client/client.go index fe9298f97c..133611a370 100644 --- a/client/client.go +++ b/client/client.go @@ -10,6 +10,7 @@ import ( "net/url" "reflect" "strings" + "time" "unicode/utf8" "github.com/tidepool-org/platform/errors" @@ -35,6 +36,9 @@ type Client struct { address string userAgent string errorResponseParser ErrorResponseParser + + // DefaultRequestTimeout applies to requests whose context doesn't include a timeout. + DefaultRequestTimeout time.Duration } func New(cfg *Config) (*Client, error) { @@ -49,12 +53,15 @@ func NewWithErrorParser(cfg *Config, errorResponseParser ErrorResponseParser) (* } return &Client{ - address: cfg.Address, - userAgent: cfg.UserAgent, - errorResponseParser: errorResponseParser, + address: cfg.Address, + userAgent: cfg.UserAgent, + errorResponseParser: errorResponseParser, + DefaultRequestTimeout: DefaultRequestTimeout, }, nil } +const DefaultRequestTimeout = time.Minute + func (c *Client) ConstructURL(paths ...string) string { segments := []string{} for _, path := range paths { @@ -92,6 +99,14 @@ func (c *Client) RequestStreamWithHTTPClient(ctx context.Context, method string, return nil, err } + reqCtx := req.Context() + if _, ok := reqCtx.Deadline(); !ok { + toCtx, cancel := context.WithTimeout(reqCtx, c.DefaultRequestTimeout) + defer cancel() + req = req.WithContext(toCtx) + ctx = toCtx + } + res, err := httpClient.Do(req) if err != nil { return nil, errors.Wrapf(err, "unable to perform request to %s %s", method, url) @@ -152,13 +167,11 @@ func (c *Client) createRequest(ctx context.Context, method string, url string, m } } - req, err := http.NewRequest(method, url, body) + req, err := http.NewRequestWithContext(ctx, method, url, body) if err != nil { return nil, errors.Wrapf(err, "unable to create request to %s %s", method, url) } - req = req.WithContext(ctx) - for _, mutator := range mutators { if err = mutator.MutateRequest(req); err != nil { return nil, errors.Wrapf(err, "unable to mutate request to %s %s", method, url) diff --git a/client/client_test.go b/client/client_test.go index 29bd247b8b..9de429771f 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" "strings" + "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -350,6 +351,40 @@ var _ = Describe("Client", func() { Expect(reader).To(BeNil()) }) + Context("request timeouts", func() { + It("aren't overwritten by the default", func() { + deadline := time.Now().Add(time.Second) + toCtx, cancel := context.WithDeadline(ctx, deadline) + defer cancel() + server.AppendHandlers(RespondWith(http.StatusNoContent, nil)) + + rc, err := clnt.RequestStreamWithHTTPClient(toCtx, method, url, nil, nil, nil, httpClient) + Expect(err).To(Succeed()) + if rc != nil { + defer rc.Close() + } + t, found := toCtx.Deadline() + Expect(found).To(Equal(true)) + Expect(t).To(Equal(deadline)) + }) + + It("uses the default", func() { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + server.AppendHandlers(waitThenReturn(ctx, 10*time.Second)) + shortTimeout := 10 * time.Millisecond + clnt.DefaultRequestTimeout = shortTimeout + + start := time.Now() + rc, err := clnt.RequestStreamWithHTTPClient(ctx, method, url, nil, nil, nil, httpClient) + Expect(err).To(MatchError(ContainSubstring("context deadline exceeded"))) + if rc != nil { + defer rc.Close() + } + Expect(time.Since(start) < 2*shortTimeout).To(Equal(true)) + }) + }) + Context("with a successful response and no request body, but inspector returns error", func() { var responseErr error var errorInspector *requestTest.ResponseInspector @@ -1312,3 +1347,15 @@ var _ = Describe("Client", func() { }) }) }) + +func waitThenReturn(ctx context.Context, dur time.Duration) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + select { + case <-time.After(dur): // wait a while... + RespondWith(http.StatusInternalServerError, nil) + case <-ctx.Done(): // ...unless the test is ended + RespondWith(http.StatusNoContent, nil) + } + } + +}