diff --git a/binding.go b/binding.go index 94e2add..510bf3f 100644 --- a/binding.go +++ b/binding.go @@ -418,3 +418,65 @@ func MultipartForm(model interface{}, opts ...Options) flamego.Handler { } }) } + +func mapQuery(obj reflect.Value, values url.Values, errs Errors) Errors { + for i := 0; i < obj.Elem().NumField(); i++ { + f := obj.Elem().Field(i) + t := obj.Elem().Type().Field(i) + + if f.Kind() == reflect.Struct && t.Anonymous { + errs = mapQuery(f.Addr(), values, errs) + continue + } + + tag := t.Tag.Get("query") + if tag == "-" { + continue + } + + if tag == "" { + tag = t.Name + } + + if !f.CanSet() { + continue + } + + vals := values[tag] + if len(vals) == 0 { + continue + } + + err := setWithProperType(f.Kind(), vals[0], f, tag) + if err != nil { + errs = append(errs, *err) + } + } + + return errs +} + +func Query(model interface{}, opts ...Options) flamego.Handler { + ensureNotPointer(model) + + var opt Options + if len(opts) > 0 { + opt = opts[0] + } + opt = parseOptions(opt) + + return flamego.ContextInvoker(func(c flamego.Context) { + var errs Errors + r := c.Request().Request + obj := reflect.New(reflect.TypeOf(model)) + errs = mapQuery(obj, r.URL.Query(), errs) + validateAndMap(c, opt.Validator, obj, errs) + errs = c.Value(reflect.TypeOf(errs)).Interface().(Errors) + if len(errs) > 0 && opt.ErrorHandler != nil { + _, err := c.Invoke(opt.ErrorHandler) + if err != nil { + panic("binding.Query: " + err.Error()) + } + } + }) +} diff --git a/binding_test.go b/binding_test.go index 3b92134..71075b4 100644 --- a/binding_test.go +++ b/binding_test.go @@ -12,6 +12,7 @@ import ( "mime/multipart" "net/http" "net/http/httptest" + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -540,7 +541,7 @@ male: true email: logan.smith@example.com weight: 60.7 balance: -12.4 -address: +address: street: 404 Broadway city: Browser planet: Internet @@ -580,7 +581,7 @@ male: true email: logan.smith@example.com weight: 60.7 balance: -12.4 -address: +address: street: 404 Broadway city: Browser planet: Internet @@ -620,7 +621,7 @@ male: bad email: logan.smith@example.com weight: 60.7 balance: -12.4 -address: +address: street: 404 Broadway city: Browser planet: Internet @@ -735,10 +736,10 @@ address: first_name: Logan last_name: Smith age: -height: -male: +height: +male: email: logan.smith@example.com -weight: +weight: balance: address: street: 404 Broadway @@ -776,7 +777,7 @@ male: on email: logan.smith@example.com weight: 60.7 balance: -12.4 -address: +address: street: 404 Broadway city: Browser planet: Internet @@ -1329,3 +1330,108 @@ func TestMultipartForm(t *testing.T) { assert.NotNil(t, gotForm.Background) assert.Len(t, gotForm.Pictures, 2) } + +func TestQuery(t *testing.T) { + t.Run("pointer model", func(t *testing.T) { + assert.PanicsWithValue(t, + "binding: pointer can not be accepted as binding model", + func() { + type form struct { + Username string + Password string + } + Query(&form{}) + }, + ) + }) + + t.Run("custom error handler", func(t *testing.T) { + type query struct { + Username string `query:"username" validate:"required"` + Password string `query:"password" validate:"required"` + } + + normalHandler := func(rw http.ResponseWriter, errs Errors) { + rw.WriteHeader(http.StatusBadRequest) + _, _ = rw.Write([]byte(errs[0].Err.Error())) + } + + fastInvokerHandler := func(c flamego.Context, errs Errors) { + c.ResponseWriter().WriteHeader(http.StatusBadRequest) + _, _ = c.ResponseWriter().Write([]byte(fmt.Sprintf("Oops! Error occurred: %v", errs[0].Err))) + } + + tests := []struct { + name string + fields map[string]string + handler flamego.Handler + statusCode int + want string + }{ + { + name: "validation error", + fields: map[string]string{ + "username": "alice", + }, + handler: fastInvokerHandler, + statusCode: http.StatusBadRequest, + want: `Oops! Error occurred: Key: "query.Password" Error: Field validation for "Password" failed on the "required" tag`, + }, + { + name: "normal handler", + fields: map[string]string{ + "username": "alice", + }, + handler: normalHandler, + statusCode: http.StatusBadRequest, + want: `Key: "query.Password" Error: Field validation for "Password" failed on the "required" tag`, + }, + { + name: "fast invoker handler", + fields: map[string]string{ + "username": "alice", + "password": "superSecurePassword", + }, + handler: fastInvokerHandler, + statusCode: http.StatusOK, + want: "Hello world", + }, + { + name: "nil handler", + fields: map[string]string{ + "username": "alice", + }, + handler: nil, + statusCode: http.StatusOK, + want: "Hello world", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + f := flamego.New() + opts := Options{ + ErrorHandler: test.handler, + } + f.Get("/", Query(query{}, opts), func(c flamego.Context) { + _, _ = c.ResponseWriter().Write([]byte("Hello world")) + }) + + u := url.URL{Path: "/"} + q := u.Query() + for k, v := range test.fields { + q.Set(k, v) + } + u.RawQuery = q.Encode() + + var body bytes.Buffer + resp := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, u.String(), &body) + assert.Nil(t, err) + + f.ServeHTTP(resp, req) + assert.Equal(t, test.statusCode, resp.Code) + assert.Equal(t, test.want, resp.Body.String()) + }) + } + }) +}