diff --git a/README.md b/README.md index e53d686..2836c7e 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,15 @@ if err := client.Run(ctx, req, &respData); err != nil { } ``` +### File support via multipart form data + +By default, the package will send a JSON body. To enable the sending of files, you can opt to +use multipart form data instead using the `UseMultipartForm` option when you create your `Client`: + +``` +client := graphql.NewClient("https://machinebox.io/graphql", graphql.UseMultipartForm()) +``` + For more information, [read the godoc package documentation](http://godoc.org/github.com/machinebox/graphql) or the [blog post](https://blog.machinebox.io/a-graphql-client-library-for-go-5bffd0455878). ## Thanks diff --git a/graphql.go b/graphql.go index 83bce6a..9bea937 100644 --- a/graphql.go +++ b/graphql.go @@ -38,15 +38,15 @@ import ( "io" "mime/multipart" "net/http" - "net/textproto" "github.com/pkg/errors" ) // Client is a client for interacting with a GraphQL API. type Client struct { - endpoint string - httpClient *http.Client + endpoint string + httpClient *http.Client + useMultipartForm bool // Log is called with various debug information. // To log to standard out, use: @@ -84,6 +84,66 @@ func (c *Client) Run(ctx context.Context, req *Request, resp interface{}) error return ctx.Err() default: } + if len(req.files) > 0 && !c.useMultipartForm { + return errors.New("cannot send files with PostFields option") + } + if c.useMultipartForm { + return c.runWithPostFields(ctx, req, resp) + } + return c.runWithJSON(ctx, req, resp) +} + +func (c *Client) runWithJSON(ctx context.Context, req *Request, resp interface{}) error { + var requestBody bytes.Buffer + requestBodyObj := struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables"` + }{ + Query: req.q, + Variables: req.vars, + } + if err := json.NewEncoder(&requestBody).Encode(requestBodyObj); err != nil { + return errors.Wrap(err, "encode body") + } + c.logf(">> variables: %v", req.vars) + c.logf(">> query: %s", req.q) + gr := &graphResponse{ + Data: resp, + } + r, err := http.NewRequest(http.MethodPost, c.endpoint, &requestBody) + if err != nil { + return err + } + r.Header.Set("Content-Type", "application/json; charset=utf-8") + r.Header.Set("Accept", "application/json; charset=utf-8") + for key, values := range req.Header { + for _, value := range values { + r.Header.Add(key, value) + } + } + c.logf(">> headers: %v", r.Header) + r = r.WithContext(ctx) + res, err := c.httpClient.Do(r) + if err != nil { + return err + } + defer res.Body.Close() + var buf bytes.Buffer + if _, err := io.Copy(&buf, res.Body); err != nil { + return errors.Wrap(err, "reading body") + } + c.logf("<< %s", buf.String()) + if err := json.NewDecoder(&buf).Decode(&gr); err != nil { + return errors.Wrap(err, "decoding response") + } + if len(gr.Errors) > 0 { + // return first error + return gr.Errors[0] + } + return nil +} + +func (c *Client) runWithPostFields(ctx context.Context, req *Request, resp interface{}) error { var requestBody bytes.Buffer writer := multipart.NewWriter(&requestBody) if err := writer.WriteField("query", req.q); err != nil { @@ -122,7 +182,7 @@ func (c *Client) Run(ctx context.Context, req *Request, resp interface{}) error return err } r.Header.Set("Content-Type", writer.FormDataContentType()) - r.Header.Set("Accept", "application/json") + r.Header.Set("Accept", "application/json; charset=utf-8") for key, values := range req.Header { for _, value := range values { r.Header.Add(key, value) @@ -154,9 +214,17 @@ func (c *Client) Run(ctx context.Context, req *Request, resp interface{}) error // making requests. // NewClient(endpoint, WithHTTPClient(specificHTTPClient)) func WithHTTPClient(httpclient *http.Client) ClientOption { - return ClientOption(func(client *Client) { + return func(client *Client) { client.httpClient = httpclient - }) + } +} + +// UseMultipartForm uses multipart/form-data and activates support for +// files. +func UseMultipartForm() ClientOption { + return func(client *Client) { + client.useMultipartForm = true + } } // ClientOption are functions that are passed into NewClient to @@ -182,26 +250,9 @@ type Request struct { vars map[string]interface{} files []file - // Header mirrors the Header of a http.Request. It contains - // the request header fields either received - // by the server or to be sent by the client. - // - // If a server received a request with header lines, - // - // Host: example.com - // accept-encoding: gzip, deflate - // Accept-Language: en-us - // fOO: Bar - // foo: two - // - // then - // - // Header = map[string][]string{ - // "Accept-Encoding": {"gzip, deflate"}, - // "Accept-Language": {"en-us"}, - // "Foo": {"Bar", "two"}, - // } - Header Header + // Header represent any request headers that will be set + // when the request is made. + Header http.Header } // NewRequest makes a new Request with the specified string. @@ -222,6 +273,8 @@ func (req *Request) Var(key string, value interface{}) { } // File sets a file to upload. +// Files are only supported with a Client that was created with +// the UseMultipartForm option. func (req *Request) File(fieldname, filename string, r io.Reader) { req.files = append(req.files, file{ Field: fieldname, @@ -230,37 +283,6 @@ func (req *Request) File(fieldname, filename string, r io.Reader) { }) } -// A Header represents the key-value pairs in an HTTP header. -type Header map[string][]string - -// Add adds the key, value pair to the header. -// It appends to any existing values associated with key. -func (h Header) Add(key, value string) { - textproto.MIMEHeader(h).Add(key, value) -} - -// Set sets the header entries associated with key to -// the single element value. It replaces any existing -// values associated with key. -func (h Header) Set(key, value string) { - textproto.MIMEHeader(h).Set(key, value) -} - -// Get gets the first value associated with the given key. -// It is case insensitive; textproto.CanonicalMIMEHeaderKey is used -// to canonicalize the provided key. -// If there are no values associated with the key, Get returns "". -// To access multiple values of a key, or to use non-canonical keys, -// access the map directly. -func (h Header) Get(key string) string { - return textproto.MIMEHeader(h).Get(key) -} - -// Del deletes the values associated with key. -func (h Header) Del(key string) { - textproto.MIMEHeader(h).Del(key) -} - // file represents a file to upload. type file struct { Field string diff --git a/graphql_json_test.go b/graphql_json_test.go new file mode 100644 index 0000000..71f7f56 --- /dev/null +++ b/graphql_json_test.go @@ -0,0 +1,107 @@ +package graphql + +import ( + "context" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matryer/is" +) + +func TestDoJSON(t *testing.T) { + is := is.New(t) + var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + is.Equal(r.Method, http.MethodPost) + b, err := ioutil.ReadAll(r.Body) + is.NoErr(err) + is.Equal(string(b), `{"query":"query {}","variables":null}`+"\n") + io.WriteString(w, `{ + "data": { + "something": "yes" + } + }`) + })) + defer srv.Close() + + ctx := context.Background() + client := NewClient(srv.URL) + + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + var responseData map[string]interface{} + err := client.Run(ctx, &Request{q: "query {}"}, &responseData) + is.NoErr(err) + is.Equal(calls, 1) // calls + is.Equal(responseData["something"], "yes") +} + +func TestQueryJSON(t *testing.T) { + is := is.New(t) + + var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + b, err := ioutil.ReadAll(r.Body) + is.NoErr(err) + is.Equal(string(b), `{"query":"query {}","variables":{"username":"matryer"}}`+"\n") + _, err = io.WriteString(w, `{"data":{"value":"some data"}}`) + is.NoErr(err) + })) + defer srv.Close() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + client := NewClient(srv.URL) + + req := NewRequest("query {}") + req.Var("username", "matryer") + + // check variables + is.True(req != nil) + is.Equal(req.vars["username"], "matryer") + + var resp struct { + Value string + } + err := client.Run(ctx, req, &resp) + is.NoErr(err) + is.Equal(calls, 1) + + is.Equal(resp.Value, "some data") +} + +func TestHeader(t *testing.T) { + is := is.New(t) + + var calls int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + is.Equal(r.Header.Get("X-Custom-Header"), "123") + + _, err := io.WriteString(w, `{"data":{"value":"some data"}}`) + is.NoErr(err) + })) + defer srv.Close() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + client := NewClient(srv.URL) + + req := NewRequest("query {}") + req.Header.Set("X-Custom-Header", "123") + + var resp struct { + Value string + } + err := client.Run(ctx, req, &resp) + is.NoErr(err) + is.Equal(calls, 1) + + is.Equal(resp.Value, "some data") +} diff --git a/graphql_test.go b/graphql_multipart_test.go similarity index 91% rename from graphql_test.go rename to graphql_multipart_test.go index 01bd725..1756d95 100644 --- a/graphql_test.go +++ b/graphql_multipart_test.go @@ -27,7 +27,7 @@ func TestWithClient(t *testing.T) { } ctx := context.Background() - client := NewClient("", WithHTTPClient(testClient)) + client := NewClient("", WithHTTPClient(testClient), UseMultipartForm()) req := NewRequest(``) client.Run(ctx, req, nil) @@ -35,7 +35,7 @@ func TestWithClient(t *testing.T) { is.Equal(calls, 1) // calls } -func TestDo(t *testing.T) { +func TestDoUseMultipartForm(t *testing.T) { is := is.New(t) var calls int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -52,7 +52,7 @@ func TestDo(t *testing.T) { defer srv.Close() ctx := context.Background() - client := NewClient(srv.URL) + client := NewClient(srv.URL, UseMultipartForm()) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -80,7 +80,7 @@ func TestDoErr(t *testing.T) { defer srv.Close() ctx := context.Background() - client := NewClient(srv.URL) + client := NewClient(srv.URL, UseMultipartForm()) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -107,7 +107,7 @@ func TestDoNoResponse(t *testing.T) { defer srv.Close() ctx := context.Background() - client := NewClient(srv.URL) + client := NewClient(srv.URL, UseMultipartForm()) ctx, cancel := context.WithTimeout(ctx, 1*time.Second) defer cancel() @@ -132,7 +132,7 @@ func TestQuery(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - client := NewClient(srv.URL) + client := NewClient(srv.URL, UseMultipartForm()) req := NewRequest("query {}") req.Var("username", "matryer") @@ -173,7 +173,7 @@ func TestFile(t *testing.T) { defer srv.Close() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - client := NewClient(srv.URL) + client := NewClient(srv.URL, UseMultipartForm()) f := strings.NewReader(`This is a file`) req := NewRequest("query {}") req.File("file", "filename.txt", f)