-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbind_request.go
141 lines (132 loc) · 3.29 KB
/
bind_request.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package xhttp
import (
"bytes"
xconv "github.com/goclub/conv"
xerr "github.com/goclub/error"
xjson "github.com/goclub/json"
xreflect "github.com/goclub/reflect"
"github.com/gorilla/mux"
"io/ioutil"
"net/http"
"reflect"
"strings"
)
// RequestUnmarshaler is an interface whose single method yields a string (or
// an error) from the receiver. This file only records its reflect.Type;
// nothing here calls UnmarshalRequest.
//
// NOTE(review): the signature reads like "render the receiver as a request
// value", which is the opposite of what the name suggests. Confirm against
// the code that uses it before relying on either reading.
type RequestUnmarshaler interface {
	UnmarshalRequest() (string, error)
}

// requestUnmarshalerType holds the reflect.Type of the RequestUnmarshaler
// interface itself (not of a pointer to it), which is the form that
// reflect.Type.Implements expects as its argument.
var requestUnmarshalerType = reflect.TypeOf(new(RequestUnmarshaler)).Elem()
// RequestMarshaler lets a field type take over conversion of a raw request
// value. When a pointer to the field's type implements it, parserField hands
// the non-empty param/query/form string to MarshalRequest instead of
// converting it with xconv.StringToReflect. A returned error is propagated
// out of parserField.
//
// NOTE(review): despite the name, the method consumes a string (it decodes
// into the receiver).
type RequestMarshaler interface {
	MarshalRequest(value string) error
}

// requestMarshalerType is the reflect.Type of the RequestMarshaler interface
// itself; parserField tests reflect.PtrTo(fieldType) against it.
var requestMarshalerType = reflect.TypeOf(new(RequestMarshaler)).Elem()
// bindRequestEachCounter holds a single unsigned query counter.
//
// NOTE(review): this type is not referenced anywhere in this file; BindRequest
// keeps its counters in local int variables instead. Check the rest of the
// package and delete it if nothing else uses it.
type bindRequestEachCounter struct {
	QueryCount uint
}
func BindRequest(ptr interface{}, r *http.Request) error {
// 判断ptr 必须是指针
if reflect.ValueOf(ptr).Kind() != reflect.Ptr {
return xerr.New("goclub/http: BindRequest(ptr) ptr not be pointer")
}
contentType := r.Header.Get("Content-Type")
query := r.URL.Query()
queryCount := len(query)
param := mux.Vars(r)
paramCount := len(param)
paramGet := func(key string) string {
return param[key]
}
var formCount int
// 下面的代码会重新赋值 formGet
var formGet = func(key string) string { return "" }
bindingIsOver := func() bool {
return formCount == 0 && queryCount == 0 && paramCount == 0
}
switch {
case strings.Contains(contentType, "application/x-www-form-urlencoded"):
err := r.ParseForm()
if err != nil {
return err
}
formCount = len(r.PostForm)
formGet = func(key string) string {
return r.PostForm.Get(key)
}
case strings.Contains(contentType, "multipart/form-data"):
err := r.ParseMultipartForm(32 << 20)
if err != nil {
return err
}
formCount = len(r.MultipartForm.Value)
formGet = func(key string) string {
return r.FormValue(key)
}
case strings.Contains(contentType, "application/json"):
jsonb, err := ioutil.ReadAll(r.Body)
if err != nil {
return err
}
r.Body = ioutil.NopCloser(bytes.NewBuffer(jsonb))
if len(jsonb) != 0 {
err = xjson.Unmarshal(jsonb, ptr)
if err != nil {
return xerr.WithStack(err)
}
}
default:
}
if bindingIsOver() {
return nil
}
return xreflect.DeepEach1(ptr, func(rValue reflect.Value, rType reflect.Type, field reflect.StructField) (op xreflect.EachOperator) {
if bindingIsOver() {
return op.Break()
}
/* parse param */ {
err := parserField(¶mCount, field.Tag.Get("param"), paramGet, rValue, rType)
if err != nil {
return op.Error(err)
}
}
/* parse query */ {
err := parserField(&queryCount, field.Tag.Get("query"), query.Get, rValue, rType)
if err != nil {
return op.Error(err)
}
}
/* parse form */ {
err := parserField(&formCount, field.Tag.Get("form"), formGet, rValue, rType)
if err != nil {
return op.Error(err)
}
}
return
})
}
func parserField(unresolvedCount *int, key string, get func(key string) string, rValue reflect.Value, rType reflect.Type) error {
if *unresolvedCount == 0 {
return nil
}
if key == "" {
return nil
}
value := get(key)
if value == "" {
return nil
}
/* 转换赋值 */ {
if reflect.PtrTo(rType).Implements(requestMarshalerType) {
err := rValue.Addr().Interface().(RequestMarshaler).MarshalRequest(value)
if err != nil {
return err
}
*unresolvedCount--
} else {
err := xconv.StringToReflect(value, rValue)
if err != nil {
return err
}
*unresolvedCount--
}
}
return nil
}