Skip to content

Commit

Permalink
fix(#23): persist context data across body updates
Browse files Browse the repository at this point in the history
  • Loading branch information
h2non committed Mar 17, 2017
1 parent 00768d4 commit b306884
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 11 deletions.
47 changes: 47 additions & 0 deletions _examples/context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package main

import (
"fmt"

"gopkg.in/h2non/gentleman.v1"
)

func main() {
// Create new request instance
req := gentleman.NewRequest()
req.Method("GET")

// Define target URL
req.URL("http://httpbin.org/headers")

// Set a new header field
req.SetHeader("Client", "gentleman")

// Set sample context data
req.Context.Set("foo", "bar")
req.Context.Set("bar", "baz")

// Set sample body as string
req.BodyString("hello, gentleman!")

// Output all context data
fmt.Println(req.Context.GetAll())

// Perform the request
res, err := req.Do()
if err != nil {
fmt.Printf("Request error: %s\n", err)
return
}
if !res.Ok {
fmt.Printf("Invalid server response: %d\n", res.StatusCode)
return
}

// Set sample context data
fmt.Println(req.Context.GetString("foo"))
fmt.Println(req.Context.GetString("bar"))

// Reads the whole body and returns it as string
fmt.Printf("Body: %s", res.String())
}
30 changes: 26 additions & 4 deletions context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,23 @@ func (c *Context) Copy(req *http.Request) {
req.Body = getContextReadCloser(c.Request).Clone()
}

// WrapBody wraps and copies the current context data to a new request body.
// Since context metadata is "overloaded" as part of the request body reader
// calling this method is mandatory for any new request body.
func (c *Context) WrapBody(body io.ReadCloser) ReadCloser {
ctx := getContextReadCloser(c.Request).Context()
newBody := wrapContextReadCloser(body)
newBody.SetContext(ctx)
return newBody
}

// ReadCloser augments the io.ReadCloser interface
// with a Context() method
type ReadCloser interface {
io.ReadCloser
Clone() ReadCloser
Context() map[interface{}]interface{}
SetContext(map[interface{}]interface{})
}

type contextReadCloser struct {
Expand All @@ -174,18 +185,29 @@ func (crc *contextReadCloser) Clone() ReadCloser {
return clone
}

func getContextReadCloser(req *http.Request) ReadCloser {
crc, ok := req.Body.(ReadCloser)
func (crc *contextReadCloser) SetContext(store map[interface{}]interface{}) {
for key, value := range store {
crc.store[key] = value
}
}

func wrapContextReadCloser(body io.ReadCloser) ReadCloser {
crc, ok := body.(ReadCloser)
if !ok {
crc = &contextReadCloser{
ReadCloser: req.Body,
ReadCloser: body,
store: make(map[interface{}]interface{}),
}
req.Body = crc
}
return crc
}

func getContextReadCloser(req *http.Request) ReadCloser {
body := wrapContextReadCloser(req.Body)
req.Body = body
return body
}

func createRequest() *http.Request {
req := &http.Request{
Method: "GET",
Expand Down
8 changes: 4 additions & 4 deletions plugins/body/body.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
func String(data string) p.Plugin {
return p.NewRequestPlugin(func(ctx *c.Context, h c.Handler) {
ctx.Request.Method = getMethod(ctx)
ctx.Request.Body = utils.StringReader(data)
ctx.Request.Body = ctx.WrapBody(utils.StringReader(data))
ctx.Request.ContentLength = int64(bytes.NewBufferString(data).Len())
h.Next(ctx)
})
Expand All @@ -42,7 +42,7 @@ func JSON(data interface{}) p.Plugin {
}

ctx.Request.Method = getMethod(ctx)
ctx.Request.Body = ioutil.NopCloser(buf)
ctx.Request.Body = ctx.WrapBody(ioutil.NopCloser(buf))
ctx.Request.ContentLength = int64(buf.Len())
ctx.Request.Header.Set("Content-Type", "application/json")

Expand All @@ -69,7 +69,7 @@ func XML(data interface{}) p.Plugin {
}

ctx.Request.Method = getMethod(ctx)
ctx.Request.Body = ioutil.NopCloser(buf)
ctx.Request.Body = ctx.WrapBody(ioutil.NopCloser(buf))
ctx.Request.ContentLength = int64(buf.Len())
ctx.Request.Header.Set("Content-Type", "application/xml")

Expand Down Expand Up @@ -98,7 +98,7 @@ func Reader(body io.Reader) p.Plugin {
}
}

req.Body = rc
req.Body = ctx.WrapBody(rc)
ctx.Request.Method = getMethod(ctx)

h.Next(ctx)
Expand Down
30 changes: 28 additions & 2 deletions plugins/body/body_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package body

import (
"bytes"
"github.com/nbio/st"
"gopkg.in/h2non/gentleman.v1/context"
"io/ioutil"
"testing"

"github.com/nbio/st"
"gopkg.in/h2non/gentleman.v1/context"
)

func TestBodyJSONEncodeMap(t *testing.T) {
Expand Down Expand Up @@ -122,6 +123,31 @@ func TestBodyReader(t *testing.T) {
st.Expect(t, string(buf), "foo bar")
}

func TestBodyReaderContextDataSharing(t *testing.T) {
ctx := context.New()
ctx.Request.Method = "POST"
fn := newHandler()

// Set sample context data
ctx.Set("foo", "bar")
ctx.Set("bar", "baz")

reader := bytes.NewReader([]byte("foo bar"))
Reader(reader).Exec("request", ctx, fn.fn)
st.Expect(t, fn.called, true)

buf, err := ioutil.ReadAll(ctx.Request.Body)
st.Expect(t, err, nil)
st.Expect(t, ctx.Request.Method, "POST")
st.Expect(t, ctx.Request.Header.Get("Content-Type"), "")
st.Expect(t, int(ctx.Request.ContentLength), 7)
st.Expect(t, string(buf), "foo bar")

// Test context data
st.Expect(t, ctx.GetString("foo"), "bar")
st.Expect(t, ctx.GetString("bar"), "baz")
}

type handler struct {
fn context.Handler
called bool
Expand Down
2 changes: 1 addition & 1 deletion plugins/multipart/multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func createForm(data FormData, ctx *c.Context) error {
}

ctx.Request.Method = setMethod(ctx)
ctx.Request.Body = ioutil.NopCloser(body)
ctx.Request.Body = ctx.WrapBody(ioutil.NopCloser(body))
ctx.Request.Header.Add("Content-Type", multipartWriter.FormDataContentType())

return nil
Expand Down

0 comments on commit b306884

Please sign in to comment.