-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add Delayed round-tripper * split roundtripper into more specific tasks
- Loading branch information
Showing
2 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package transport | ||
|
||
import ( | ||
"fmt" | ||
"math/rand" | ||
"net/http" | ||
"time" | ||
) | ||
|
||
// DelayedRequest is a middleware that delays requests, useful when testing | ||
// timeouts while waiting on a request to be sent upstream. | ||
func DelayedRequest(requestDelayMin, requestDelayMax time.Duration) func(http.RoundTripper) http.RoundTripper { | ||
if requestDelayMin > requestDelayMax { | ||
panic(fmt.Sprintf("requestDelayMin %v is greater than requestDelayMax %v", requestDelayMin, requestDelayMax)) | ||
} | ||
return delayedRoundTripper(randDelay(requestDelayMin, requestDelayMax), 0) | ||
} | ||
|
||
// DelayedResponse is a middleware that delays responses, useful when testing | ||
// timeouts after upstream has processed the request, the response is hold back | ||
// until the delay is over. | ||
func DelayedResponse(responseDelayMin, responseDelayMax time.Duration) func(http.RoundTripper) http.RoundTripper { | ||
if responseDelayMin > responseDelayMax { | ||
panic(fmt.Sprintf("responseDelayMin %v is greater than responseDelayMax %v", responseDelayMin, responseDelayMax)) | ||
} | ||
return delayedRoundTripper(0, randDelay(responseDelayMin, responseDelayMax)) | ||
} | ||
|
||
func delayedRoundTripper(requestDelay, responseDelay time.Duration) func(http.RoundTripper) http.RoundTripper { | ||
return func(next http.RoundTripper) http.RoundTripper { | ||
return RoundTripFunc(func(req *http.Request) (*http.Response, error) { | ||
ctx := req.Context() | ||
|
||
// wait before sending request | ||
if requestDelay > 0 { | ||
ticker := time.NewTicker(requestDelay) | ||
defer ticker.Stop() | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
case <-ticker.C: | ||
} | ||
} | ||
|
||
res, err := next.RoundTrip(req) | ||
|
||
// wait before sending response body | ||
if responseDelay > 0 { | ||
ticker := time.NewTicker(responseDelay) | ||
defer ticker.Stop() | ||
|
||
select { | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
case <-ticker.C: | ||
} | ||
} | ||
|
||
return res, err | ||
}) | ||
} | ||
} | ||
|
||
func randDelay(min, max time.Duration) time.Duration { | ||
if min >= max { | ||
return min | ||
} | ||
return min + time.Duration(rand.Int63n(int64(max-min))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
package transport_test | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io/ioutil" | ||
"testing" | ||
"time" | ||
|
||
"net/http" | ||
"net/http/httptest" | ||
|
||
"github.com/go-chi/transport" | ||
) | ||
|
||
func TestDelayed(t *testing.T) { | ||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
fmt.Fprintf(w, "ok") | ||
})) | ||
defer server.Close() | ||
|
||
t.Run("default config", func(t *testing.T) { | ||
client := &http.Client{ | ||
Transport: transport.Chain( | ||
nil, | ||
transport.DelayedRequest(0, 0), | ||
transport.DelayedResponse(0, 0), | ||
), | ||
} | ||
|
||
request, err := http.NewRequest("GET", server.URL, nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
timeStart := time.Now() | ||
resp, err := client.Do(request) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
timeElapsed := time.Since(timeStart) | ||
|
||
t.Logf("elapsed time: %v", timeElapsed) | ||
|
||
if resp.StatusCode != 200 { | ||
t.Fatal("expected some header, but did not receive") | ||
} | ||
|
||
buf, err := ioutil.ReadAll(resp.Body) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
t.Logf("response: %s", string(buf)) | ||
}) | ||
|
||
t.Run("delayed response", func(t *testing.T) { | ||
client := &http.Client{ | ||
Transport: transport.Chain( | ||
nil, | ||
transport.DelayedResponse(100*time.Millisecond, 200*time.Millisecond), | ||
), | ||
} | ||
|
||
request, err := http.NewRequest("GET", server.URL, nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
timeStart := time.Now() | ||
_, err = client.Do(request) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
timeElapsed := time.Since(timeStart) | ||
|
||
if timeElapsed < 100*time.Millisecond { | ||
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) | ||
} | ||
}) | ||
|
||
t.Run("delayed connect", func(t *testing.T) { | ||
client := &http.Client{ | ||
Transport: transport.Chain( | ||
nil, | ||
transport.DelayedResponse( | ||
100*time.Millisecond, | ||
200*time.Millisecond, | ||
), | ||
), | ||
} | ||
|
||
request, err := http.NewRequest("GET", server.URL, nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
timeStart := time.Now() | ||
_, err = client.Do(request) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
timeElapsed := time.Since(timeStart) | ||
|
||
if timeElapsed < 100*time.Millisecond { | ||
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) | ||
} | ||
}) | ||
|
||
t.Run("delayed request and response", func(t *testing.T) { | ||
client := &http.Client{ | ||
Transport: transport.Chain( | ||
nil, | ||
transport.DelayedRequest(50*time.Millisecond, 100*time.Millisecond), | ||
transport.DelayedResponse(50*time.Millisecond, 100*time.Millisecond), | ||
), | ||
} | ||
|
||
request, err := http.NewRequest("GET", server.URL, nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
timeStart := time.Now() | ||
_, err = client.Do(request) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
timeElapsed := time.Since(timeStart) | ||
|
||
if timeElapsed < 100*time.Millisecond { | ||
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) | ||
} | ||
}) | ||
|
||
t.Run("chained transport", func(t *testing.T) { | ||
var customTransportHit bool | ||
|
||
customTransport := transport.RoundTripFunc(func(req *http.Request) (*http.Response, error) { | ||
customTransportHit = true | ||
|
||
return http.DefaultTransport.RoundTrip(req) | ||
}) | ||
|
||
client := &http.Client{ | ||
Transport: transport.Chain( | ||
customTransport, | ||
transport.DelayedRequest(100*time.Millisecond, 200*time.Millisecond), | ||
), | ||
} | ||
|
||
request, err := http.NewRequest("GET", server.URL, nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
timeStart := time.Now() | ||
_, err = client.Do(request) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
timeElapsed := time.Since(timeStart) | ||
|
||
if timeElapsed < 100*time.Millisecond { | ||
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed) | ||
} | ||
|
||
if customTransportHit == false { | ||
t.Fatal("expected custom transport to be hit, but it was not") | ||
} | ||
}) | ||
|
||
t.Run("honor request context", func(t *testing.T) { | ||
client := &http.Client{ | ||
Transport: transport.Chain( | ||
nil, | ||
transport.DelayedRequest(100*time.Millisecond, 200*time.Millisecond), | ||
), | ||
} | ||
|
||
request, err := http.NewRequest("GET", server.URL, nil) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) | ||
defer cancel() | ||
|
||
request = request.WithContext(ctx) | ||
|
||
timeStart := time.Now() | ||
_, err = client.Do(request) | ||
timeElapsed := time.Since(timeStart) | ||
|
||
if err == nil { | ||
t.Fatalf("expected error, but got none") | ||
} | ||
|
||
if timeElapsed < 50*time.Millisecond { | ||
t.Fatalf("expected at least 50ms delay, but got %v", timeElapsed) | ||
} | ||
|
||
if timeElapsed > 100*time.Millisecond { | ||
t.Fatalf("expected less than 100ms delay, but got %v", timeElapsed) | ||
} | ||
}) | ||
} |