From b306884343eef4e87dd7011f692f33a748a8ccaa Mon Sep 17 00:00:00 2001 From: Tomas Aparicio Date: Fri, 17 Mar 2017 00:22:14 +0000 Subject: [PATCH] fix(#23): persist context data across body updates --- _examples/context/context.go | 47 ++++++++++++++++++++++++++++++++++ context/context.go | 30 +++++++++++++++++++--- plugins/body/body.go | 8 +++--- plugins/body/body_test.go | 30 ++++++++++++++++++++-- plugins/multipart/multipart.go | 2 +- 5 files changed, 106 insertions(+), 11 deletions(-) create mode 100644 _examples/context/context.go diff --git a/_examples/context/context.go b/_examples/context/context.go new file mode 100644 index 0000000..bebd83b --- /dev/null +++ b/_examples/context/context.go @@ -0,0 +1,47 @@ +package main + +import ( + "fmt" + + "gopkg.in/h2non/gentleman.v1" +) + +func main() { + // Create new request instance + req := gentleman.NewRequest() + req.Method("GET") + + // Define target URL + req.URL("http://httpbin.org/headers") + + // Set a new header field + req.SetHeader("Client", "gentleman") + + // Set sample context data + req.Context.Set("foo", "bar") + req.Context.Set("bar", "baz") + + // Set sample body as string + req.BodyString("hello, gentleman!") + + // Output all context data + fmt.Println(req.Context.GetAll()) + + // Perform the request + res, err := req.Do() + if err != nil { + fmt.Printf("Request error: %s\n", err) + return + } + if !res.Ok { + fmt.Printf("Invalid server response: %d\n", res.StatusCode) + return + } + + // Set sample context data + fmt.Println(req.Context.GetString("foo")) + fmt.Println(req.Context.GetString("bar")) + + // Reads the whole body and returns it as string + fmt.Printf("Body: %s", res.String()) +} diff --git a/context/context.go b/context/context.go index 7840fdf..570f89b 100644 --- a/context/context.go +++ b/context/context.go @@ -146,12 +146,23 @@ func (c *Context) Copy(req *http.Request) { req.Body = getContextReadCloser(c.Request).Clone() } +// WrapBody wraps and copies the current context data to a new request body. +// Since context metadata is "overloaded" as part of the request body reader +// calling this method is mandatory for any new request body. +func (c *Context) WrapBody(body io.ReadCloser) ReadCloser { + ctx := getContextReadCloser(c.Request).Context() + newBody := wrapContextReadCloser(body) + newBody.SetContext(ctx) + return newBody +} + // ReadCloser augments the io.ReadCloser interface // with a Context() method type ReadCloser interface { io.ReadCloser Clone() ReadCloser Context() map[interface{}]interface{} + SetContext(map[interface{}]interface{}) } type contextReadCloser struct { @@ -174,18 +185,29 @@ func (crc *contextReadCloser) Clone() ReadCloser { return clone } -func getContextReadCloser(req *http.Request) ReadCloser { - crc, ok := req.Body.(ReadCloser) +func (crc *contextReadCloser) SetContext(store map[interface{}]interface{}) { + for key, value := range store { + crc.store[key] = value + } +} + +func wrapContextReadCloser(body io.ReadCloser) ReadCloser { + crc, ok := body.(ReadCloser) if !ok { crc = &contextReadCloser{ - ReadCloser: req.Body, + ReadCloser: body, store: make(map[interface{}]interface{}), } - req.Body = crc } return crc } +func getContextReadCloser(req *http.Request) ReadCloser { + body := wrapContextReadCloser(req.Body) + req.Body = body + return body +} + func createRequest() *http.Request { req := &http.Request{ Method: "GET", diff --git a/plugins/body/body.go b/plugins/body/body.go index f2ef9b2..be719d4 100644 --- a/plugins/body/body.go +++ b/plugins/body/body.go @@ -17,7 +17,7 @@ import ( func String(data string) p.Plugin { return p.NewRequestPlugin(func(ctx *c.Context, h c.Handler) { ctx.Request.Method = getMethod(ctx) - ctx.Request.Body = utils.StringReader(data) + ctx.Request.Body = ctx.WrapBody(utils.StringReader(data)) ctx.Request.ContentLength = int64(bytes.NewBufferString(data).Len()) h.Next(ctx) }) @@ -42,7 +42,7 @@ func JSON(data interface{}) p.Plugin { } ctx.Request.Method = getMethod(ctx) - ctx.Request.Body = ioutil.NopCloser(buf) + ctx.Request.Body = ctx.WrapBody(ioutil.NopCloser(buf)) ctx.Request.ContentLength = int64(buf.Len()) ctx.Request.Header.Set("Content-Type", "application/json") @@ -69,7 +69,7 @@ func XML(data interface{}) p.Plugin { } ctx.Request.Method = getMethod(ctx) - ctx.Request.Body = ioutil.NopCloser(buf) + ctx.Request.Body = ctx.WrapBody(ioutil.NopCloser(buf)) ctx.Request.ContentLength = int64(buf.Len()) ctx.Request.Header.Set("Content-Type", "application/xml") @@ -98,7 +98,7 @@ func Reader(body io.Reader) p.Plugin { } } - req.Body = rc + req.Body = ctx.WrapBody(rc) ctx.Request.Method = getMethod(ctx) h.Next(ctx) diff --git a/plugins/body/body_test.go b/plugins/body/body_test.go index 95f281c..c15c9c6 100644 --- a/plugins/body/body_test.go +++ b/plugins/body/body_test.go @@ -2,10 +2,11 @@ package body import ( "bytes" - "github.com/nbio/st" - "gopkg.in/h2non/gentleman.v1/context" "io/ioutil" "testing" + + "github.com/nbio/st" + "gopkg.in/h2non/gentleman.v1/context" ) func TestBodyJSONEncodeMap(t *testing.T) { @@ -122,6 +123,31 @@ func TestBodyReader(t *testing.T) { st.Expect(t, string(buf), "foo bar") } +func TestBodyReaderContextDataSharing(t *testing.T) { + ctx := context.New() + ctx.Request.Method = "POST" + fn := newHandler() + + // Set sample context data + ctx.Set("foo", "bar") + ctx.Set("bar", "baz") + + reader := bytes.NewReader([]byte("foo bar")) + Reader(reader).Exec("request", ctx, fn.fn) + st.Expect(t, fn.called, true) + + buf, err := ioutil.ReadAll(ctx.Request.Body) + st.Expect(t, err, nil) + st.Expect(t, ctx.Request.Method, "POST") + st.Expect(t, ctx.Request.Header.Get("Content-Type"), "") + st.Expect(t, int(ctx.Request.ContentLength), 7) + st.Expect(t, string(buf), "foo bar") + + // Test context data + st.Expect(t, ctx.GetString("foo"), "bar") + st.Expect(t, ctx.GetString("bar"), "baz") +} + type handler struct { fn context.Handler called bool diff --git a/plugins/multipart/multipart.go b/plugins/multipart/multipart.go index 45863de..77d4340 100644 --- a/plugins/multipart/multipart.go +++ b/plugins/multipart/multipart.go @@ -94,7 +94,7 @@ func createForm(data FormData, ctx *c.Context) error { } ctx.Request.Method = setMethod(ctx) - ctx.Request.Body = ioutil.NopCloser(body) + ctx.Request.Body = ctx.WrapBody(ioutil.NopCloser(body)) ctx.Request.Header.Add("Content-Type", multipartWriter.FormDataContentType()) return nil