Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix default retry policy for certificate verification errors and bad request headers #210

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cert_error_go119.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !go1.20
// +build !go1.20

package retryablehttp

import "crypto/x509"

func isCertError(err error) bool {
_, ok := err.(x509.UnknownAuthorityError)
return ok
}
14 changes: 14 additions & 0 deletions cert_error_go120.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build go1.20
// +build go1.20

package retryablehttp

import "crypto/tls"

func isCertError(err error) bool {
_, ok := err.(*tls.CertificateVerificationError)
return ok
}
19 changes: 14 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ package retryablehttp
import (
"bytes"
"context"
"crypto/x509"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -62,6 +61,10 @@ var (
// limit the size we consume to respReadLimit.
respReadLimit = int64(4096)

// timeNow sets the function that returns the current time.
// This defaults to time.Now. Changes to this should only be done in tests.
timeNow = time.Now

// A regular expression to match the error returned by net/http when the
// configured number of redirects is exhausted. This error isn't typed
// specifically so we resort to matching on the error string.
Expand All @@ -72,9 +75,10 @@ var (
// specifically so we resort to matching on the error string.
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)

// timeNow sets the function that returns the current time.
// This defaults to time.Now. Changes to this should only be done in tests.
timeNow = time.Now
// A regular expression to match the error returned by net/http when a
// request header or value is invalid. This error isn't typed
// specifically so we resort to matching on the error string.
invalidHeaderErrorRe = regexp.MustCompile(`invalid header`)

// A regular expression to match the error returned by net/http when the
// TLS certificate is not trusted. This error isn't typed
Expand Down Expand Up @@ -501,11 +505,16 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
return false, v
}

// Don't retry if the error was due to an invalid header.
if invalidHeaderErrorRe.MatchString(v.Error()) {
return false, v
}

// Don't retry if the error was due to TLS cert verification failure.
if notTrustedErrorRe.MatchString(v.Error()) {
return false, v
}
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
if isCertError(v.Err) {
return false, v
}
}
Expand Down
54 changes: 54 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,60 @@ func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) {
}
}

func TestClient_DefaultRetryPolicy_invalidheadername(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

attempts := 0
client := NewClient()
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
attempts++
return DefaultRetryPolicy(context.TODO(), resp, err)
}

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.Header.Set("Header-Name-\033", "header value")
_, err = client.StandardClient().Do(req)
if err == nil {
t.Fatalf("expected header error, got nil")
}
if attempts != 1 {
t.Fatalf("expected 1 attempt, got %d", attempts)
}
}

func TestClient_DefaultRetryPolicy_invalidheadervalue(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

attempts := 0
client := NewClient()
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
attempts++
return DefaultRetryPolicy(context.TODO(), resp, err)
}

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.Header.Set("Header-Name", "bad header value \033")
_, err = client.StandardClient().Do(req)
if err == nil {
t.Fatalf("expected header value error, got nil")
}
if attempts != 1 {
t.Fatalf("expected 1 attempt, got %d", attempts)
}
}

func TestClient_CheckRetryStop(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "test_500_body", http.StatusInternalServerError)
Expand Down