Skip to content

Commit

Permalink
Merge pull request #210 from tomclegg/noretry-header-cert
Browse files Browse the repository at this point in the history
Fix default retry policy for certificate verification errors and bad request headers
  • Loading branch information
manicminer authored May 9, 2024
2 parents 4fb315e + eb08cce commit 1643719
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 5 deletions.
14 changes: 14 additions & 0 deletions cert_error_go119.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !go1.20
// +build !go1.20

package retryablehttp

import "crypto/x509"

func isCertError(err error) bool {
_, ok := err.(x509.UnknownAuthorityError)
return ok
}
14 changes: 14 additions & 0 deletions cert_error_go120.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build go1.20
// +build go1.20

package retryablehttp

import "crypto/tls"

func isCertError(err error) bool {
_, ok := err.(*tls.CertificateVerificationError)
return ok
}
19 changes: 14 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ package retryablehttp
import (
"bytes"
"context"
"crypto/x509"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -62,6 +61,10 @@ var (
// limit the size we consume to respReadLimit.
respReadLimit = int64(4096)

// timeNow sets the function that returns the current time.
// This defaults to time.Now. Changes to this should only be done in tests.
timeNow = time.Now

// A regular expression to match the error returned by net/http when the
// configured number of redirects is exhausted. This error isn't typed
// specifically so we resort to matching on the error string.
Expand All @@ -72,9 +75,10 @@ var (
// specifically so we resort to matching on the error string.
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)

// timeNow sets the function that returns the current time.
// This defaults to time.Now. Changes to this should only be done in tests.
timeNow = time.Now
// A regular expression to match the error returned by net/http when a
// request header or value is invalid. This error isn't typed
// specifically so we resort to matching on the error string.
invalidHeaderErrorRe = regexp.MustCompile(`invalid header`)

// A regular expression to match the error returned by net/http when the
// TLS certificate is not trusted. This error isn't typed
Expand Down Expand Up @@ -501,11 +505,16 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) {
return false, v
}

// Don't retry if the error was due to an invalid header.
if invalidHeaderErrorRe.MatchString(v.Error()) {
return false, v
}

// Don't retry if the error was due to TLS cert verification failure.
if notTrustedErrorRe.MatchString(v.Error()) {
return false, v
}
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
if isCertError(v.Err) {
return false, v
}
}
Expand Down
54 changes: 54 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,60 @@ func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) {
}
}

func TestClient_DefaultRetryPolicy_invalidheadername(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

attempts := 0
client := NewClient()
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
attempts++
return DefaultRetryPolicy(context.TODO(), resp, err)
}

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.Header.Set("Header-Name-\033", "header value")
_, err = client.StandardClient().Do(req)
if err == nil {
t.Fatalf("expected header error, got nil")
}
if attempts != 1 {
t.Fatalf("expected 1 attempt, got %d", attempts)
}
}

func TestClient_DefaultRetryPolicy_invalidheadervalue(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

attempts := 0
client := NewClient()
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
attempts++
return DefaultRetryPolicy(context.TODO(), resp, err)
}

req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}
req.Header.Set("Header-Name", "bad header value \033")
_, err = client.StandardClient().Do(req)
if err == nil {
t.Fatalf("expected header value error, got nil")
}
if attempts != 1 {
t.Fatalf("expected 1 attempt, got %d", attempts)
}
}

func TestClient_CheckRetryStop(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "test_500_body", http.StatusInternalServerError)
Expand Down

0 comments on commit 1643719

Please sign in to comment.