Skip to content

Commit

Permalink
add Delayed round-tripper (#12)
Browse files Browse the repository at this point in the history
* add Delayed round-tripper

* split roundtripper into more specific tasks
  • Loading branch information
xiam authored Nov 22, 2024
1 parent 23bf6fa commit ab7e622
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 0 deletions.
70 changes: 70 additions & 0 deletions delayed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package transport

import (
"fmt"
"math/rand"
"net/http"
"time"
)

// DelayedRequest is a middleware that delays requests, useful when testing
// timeouts while waiting on a request to be sent upstream.
func DelayedRequest(requestDelayMin, requestDelayMax time.Duration) func(http.RoundTripper) http.RoundTripper {
if requestDelayMin > requestDelayMax {
panic(fmt.Sprintf("requestDelayMin %v is greater than requestDelayMax %v", requestDelayMin, requestDelayMax))
}
return delayedRoundTripper(randDelay(requestDelayMin, requestDelayMax), 0)
}

// DelayedResponse is a middleware that delays responses, useful when testing
// timeouts after upstream has processed the request, the response is hold back
// until the delay is over.
func DelayedResponse(responseDelayMin, responseDelayMax time.Duration) func(http.RoundTripper) http.RoundTripper {
if responseDelayMin > responseDelayMax {
panic(fmt.Sprintf("responseDelayMin %v is greater than responseDelayMax %v", responseDelayMin, responseDelayMax))
}
return delayedRoundTripper(0, randDelay(responseDelayMin, responseDelayMax))
}

func delayedRoundTripper(requestDelay, responseDelay time.Duration) func(http.RoundTripper) http.RoundTripper {
return func(next http.RoundTripper) http.RoundTripper {
return RoundTripFunc(func(req *http.Request) (*http.Response, error) {
ctx := req.Context()

// wait before sending request
if requestDelay > 0 {
ticker := time.NewTicker(requestDelay)
defer ticker.Stop()

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
}
}

res, err := next.RoundTrip(req)

// wait before sending response body
if responseDelay > 0 {
ticker := time.NewTicker(responseDelay)
defer ticker.Stop()

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
}
}

return res, err
})
}
}

func randDelay(min, max time.Duration) time.Duration {
if min >= max {
return min
}
return min + time.Duration(rand.Int63n(int64(max-min)))
}
207 changes: 207 additions & 0 deletions delayed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package transport_test

import (
"context"
"fmt"
"io/ioutil"
"testing"
"time"

"net/http"
"net/http/httptest"

"github.com/go-chi/transport"
)

func TestDelayed(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok")
}))
defer server.Close()

t.Run("default config", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.DelayedRequest(0, 0),
transport.DelayedResponse(0, 0),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
resp, err := client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

t.Logf("elapsed time: %v", timeElapsed)

if resp.StatusCode != 200 {
t.Fatal("expected some header, but did not receive")
}

buf, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}

t.Logf("response: %s", string(buf))
})

t.Run("delayed response", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.DelayedResponse(100*time.Millisecond, 200*time.Millisecond),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}
})

t.Run("delayed connect", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.DelayedResponse(
100*time.Millisecond,
200*time.Millisecond,
),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}
})

t.Run("delayed request and response", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.DelayedRequest(50*time.Millisecond, 100*time.Millisecond),
transport.DelayedResponse(50*time.Millisecond, 100*time.Millisecond),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}
})

t.Run("chained transport", func(t *testing.T) {
var customTransportHit bool

customTransport := transport.RoundTripFunc(func(req *http.Request) (*http.Response, error) {
customTransportHit = true

return http.DefaultTransport.RoundTrip(req)
})

client := &http.Client{
Transport: transport.Chain(
customTransport,
transport.DelayedRequest(100*time.Millisecond, 200*time.Millisecond),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

timeStart := time.Now()
_, err = client.Do(request)
if err != nil {
t.Fatal(err)
}
timeElapsed := time.Since(timeStart)

if timeElapsed < 100*time.Millisecond {
t.Fatalf("expected at least 100ms delay, but got %v", timeElapsed)
}

if customTransportHit == false {
t.Fatal("expected custom transport to be hit, but it was not")
}
})

t.Run("honor request context", func(t *testing.T) {
client := &http.Client{
Transport: transport.Chain(
nil,
transport.DelayedRequest(100*time.Millisecond, 200*time.Millisecond),
),
}

request, err := http.NewRequest("GET", server.URL, nil)
if err != nil {
t.Fatal(err)
}

ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()

request = request.WithContext(ctx)

timeStart := time.Now()
_, err = client.Do(request)
timeElapsed := time.Since(timeStart)

if err == nil {
t.Fatalf("expected error, but got none")
}

if timeElapsed < 50*time.Millisecond {
t.Fatalf("expected at least 50ms delay, but got %v", timeElapsed)
}

if timeElapsed > 100*time.Millisecond {
t.Fatalf("expected less than 100ms delay, but got %v", timeElapsed)
}
})
}

0 comments on commit ab7e622

Please sign in to comment.