Skip to content

Commit

Permalink
refractor: response
Browse files Browse the repository at this point in the history
  • Loading branch information
imstevez committed Sep 4, 2023
1 parent ac30bbe commit 2fb46c1
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 101 deletions.
168 changes: 88 additions & 80 deletions s3/responses/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ const (

var errValueNotSet = fmt.Errorf("value not set")

var byteSliceType = reflect.TypeOf([]byte{})

func WriteResponse(w http.ResponseWriter, statusCode int, output interface{}, locationName string) (err error) {
if locationName != "" {
output = wrapOutput(output, locationName)
setCommonHeaders(w.Header())

outv := reflect.Indirect(reflect.ValueOf(wrapOutput(output, locationName)))
if !outv.IsValid() {
w.WriteHeader(statusCode)
return
}

defer func() {
Expand All @@ -46,32 +48,31 @@ func WriteResponse(w http.ResponseWriter, statusCode int, output interface{}, lo
}
}()

if !valid(output) {
w.WriteHeader(statusCode)
err = setFieldRequestID(w.Header(), outv)
if err != nil {
return
}

body, clen, ctyp, err := extractBody(output)
body, clen, ctyp, err := extractBody(outv)
if err != nil {
return
}

if body == nil {
err = extractHeaders(w.Header(), output)
err = setLocationHeaders(w.Header(), outv)
if err != nil {
return
}
w.WriteHeader(statusCode)
return
}

defer func() {
_ = body.Close()
}()
defer body.Close()

w.Header().Set(consts.ContentLength, fmt.Sprintf("%d", clen))
w.Header().Set(consts.ContentType, ctyp)

err = extractHeaders(w.Header(), output)
err = setLocationHeaders(w.Header(), outv)
if err != nil {
return
}
Expand All @@ -84,6 +85,11 @@ func WriteResponse(w http.ResponseWriter, statusCode int, output interface{}, lo
}

func wrapOutput(v interface{}, locationName string) (wrapper interface{}) {
if locationName == "" {
wrapper = v
return
}

outputTag := fmt.Sprintf(`locationName:"%s" type:"structure"`, locationName)
fields := []reflect.StructField{
{
Expand All @@ -105,16 +111,16 @@ func wrapOutput(v interface{}, locationName string) (wrapper interface{}) {
return
}

func extractBody(output interface{}) (body io.ReadCloser, clen int, ctyp string, err error) {
ptyp, plod := getPayload(output)
func extractBody(v reflect.Value) (body io.ReadCloser, clen int, ctyp string, err error) {
ptyp, plod := getPayload(v)
if ptyp == noPayload {
return
}

if ptyp == "structure" || ptyp == "" {
var buf bytes.Buffer
buf.WriteString(xml.Header)
err = xmlutil.BuildXML(output, xml.NewEncoder(&buf))
err = xmlutil.BuildXML(v.Interface(), xml.NewEncoder(&buf))
if err != nil {
return
}
Expand All @@ -132,6 +138,15 @@ func extractBody(output interface{}) (body io.ReadCloser, clen int, ctyp string,
case io.ReadCloser:
body = pifc
clen = -1
case io.ReadSeeker:
var bs []byte
bs, err = io.ReadAll(pifc)
if err != nil {
return
}
body = io.NopCloser(bytes.NewBuffer(bs))
clen = len(bs)
ctyp = http.DetectContentType(bs)
case []byte:
body = io.NopCloser(bytes.NewBuffer(pifc))
clen = len(pifc)
Expand All @@ -148,45 +163,44 @@ func extractBody(output interface{}) (body io.ReadCloser, clen int, ctyp string,
return
}

func getPayload(output interface{}) (ptyp string, plod reflect.Value) {
ptyp = noPayload
v := reflect.Indirect(reflect.ValueOf(output))
if !v.IsValid() {
return
}
field, ok := v.Type().FieldByName("_")
if !ok {
return
}
noPayloadValue := field.Tag.Get(noPayload)
if noPayloadValue != "" {
return
}
payloadName := field.Tag.Get("payload")
if payloadName == "" {
func setFieldRequestID(headers http.Header, outv reflect.Value) (err error) {
reqId := headers.Get(consts.AmzRequestID)

idv := outv.FieldByName("RequestID")
if !idv.IsValid() {
return
}
member, ok := v.Type().FieldByName(payloadName)
if !ok {
return

switch idv.Interface().(type) {
case *string:
idv.Set(reflect.ValueOf(&reqId))
case string:
idv.Set(reflect.ValueOf(reqId))
default:
err = errValueNotSet
}
ptyp = member.Tag.Get("type")
plod = reflect.Indirect(v.FieldByName(payloadName))

return
}

func extractHeaders(header http.Header, output interface{}) (err error) {
v := reflect.ValueOf(output).Elem()
func setCommonHeaders(headers http.Header) {
reqId := getRequestID()
headers.Set(consts.ServerInfo, consts.DefaultServerInfo)
headers.Set(consts.AcceptRanges, "bytes")
headers.Set(consts.AmzRequestID, reqId)
}

func setLocationHeaders(header http.Header, v reflect.Value) (err error) {
for i := 0; i < v.NumField(); i++ {
ft := v.Type().Field(i)
fv := v.Field(i)
ft := v.Type().Field(i)
fk := fv.Kind()

if !fv.IsValid() {
if n := ft.Name; n[0:1] == strings.ToLower(n[0:1]) {
continue
}

if n := ft.Name; n[0:1] == strings.ToLower(n[0:1]) {
if !fv.IsValid() {
continue
}

Expand All @@ -202,20 +216,12 @@ func extractHeaders(header http.Header, output interface{}) (err error) {
}
}

if ft.Tag.Get("ignore") != "" {
continue
}

if ft.Tag.Get("marshal-as") == "blob" {
fv = fv.Convert(byteSliceType)
}

switch ft.Tag.Get("location") {
case "header":
name := ifemp(ft.Tag.Get("locationName"), ft.Name)
err = writeHeader(&header, fv, name, ft.Tag)
err = setHeaders(&header, fv, name, ft.Tag)
case "headers":
err = writeHeaderMap(&header, fv, ft.Tag)
err = setHeadersMap(&header, fv, ft.Tag)
}

if err != nil {
Expand All @@ -226,7 +232,7 @@ func extractHeaders(header http.Header, output interface{}) (err error) {
return
}

func writeHeader(header *http.Header, v reflect.Value, name string, tag reflect.StructTag) (err error) {
func setHeaders(header *http.Header, v reflect.Value, name string, tag reflect.StructTag) (err error) {
str, err := convertType(v, tag)
if errors.Is(err, errValueNotSet) {
err = nil
Expand All @@ -241,7 +247,7 @@ func writeHeader(header *http.Header, v reflect.Value, name string, tag reflect.
return
}

func writeHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) (err error) {
func setHeadersMap(header *http.Header, v reflect.Value, tag reflect.StructTag) (err error) {
prefix := tag.Get("locationName")
for _, key := range v.MapKeys() {
var str string
Expand All @@ -260,6 +266,35 @@ func writeHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag)
return
}

func getPayload(v reflect.Value) (ptyp string, plod reflect.Value) {
ptyp = noPayload

field, ok := v.Type().FieldByName("_")
if !ok {
return
}

noPayloadValue := field.Tag.Get(noPayload)
if noPayloadValue != "" {
return
}

payloadName := field.Tag.Get("payload")
if payloadName == "" {
return
}

member, ok := v.Type().FieldByName(payloadName)
if !ok {
return
}

ptyp = member.Tag.Get("type")
plod = reflect.Indirect(v.FieldByName(payloadName))

return
}

func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error) {
v = reflect.Indirect(v)
if !v.IsValid() {
Expand All @@ -273,29 +308,6 @@ func convertType(v reflect.Value, tag reflect.StructTag) (str string, err error)
value = base64.StdEncoding.EncodeToString([]byte(value))
}
str = value
case []*string:
if tag.Get("location") != "header" || tag.Get("enum") == "" {
return "", fmt.Errorf("%T is only supported with location header and enum shapes", value)
}
if len(value) == 0 {
return "", errValueNotSet
}

buff := &bytes.Buffer{}
for i, sv := range value {
if sv == nil || len(*sv) == 0 {
continue
}
if i != 0 {
buff.WriteRune(',')
}
item := *sv
if strings.Index(item, `,`) != -1 || strings.Index(item, `"`) != -1 {
item = strconv.Quote(item)
}
buff.WriteString(item)
}
str = string(buff.Bytes())
case []byte:
str = base64.StdEncoding.EncodeToString(value)
case bool:
Expand Down Expand Up @@ -344,7 +356,3 @@ func ifemp(a, b string) string {
}
return b
}

func valid(ifce interface{}) bool {
return reflect.Indirect(reflect.ValueOf(ifce)).IsValid()
}
28 changes: 9 additions & 19 deletions s3/responses/responses_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,6 @@ func getRequestID() string {
return fmt.Sprintf("%d", time.Now().UnixNano())
}

func setCommonHeader(w http.ResponseWriter, requestId string) {
w.Header().Set(consts.ServerInfo, consts.DefaultServerInfo)
w.Header().Set(consts.AmzRequestID, requestId)
w.Header().Set(consts.AcceptRanges, "bytes")
}

type ErrorOutput struct {
_ struct{} `type:"structure"`
Expand All @@ -40,27 +35,22 @@ type ErrorOutput struct {
RequestID string `locationName:"RequestID" type:"string"`
}

func WriteErrorResponse(w http.ResponseWriter, r *http.Request, rerr *Error) {
reqID := getRequestID()
setCommonHeader(w, reqID)
output := &ErrorOutput{
func NewErrOutput(r *http.Request, rerr *Error) *ErrorOutput {
return &ErrorOutput{
Code: rerr.Code(),
Message: rerr.Description(),
Resource: pathClean(r.URL.Path),
RequestID: reqID,
}
err := WriteResponse(w, rerr.HTTPStatusCode(), output, "Error")
if err != nil {
fmt.Println("write response: ", err)
RequestID: "", // this field value will be automatically filled
}
}

func WriteErrorResponse(w http.ResponseWriter, r *http.Request, rerr *Error) {
output := NewErrOutput(r, rerr)
_ = WriteResponse(w, rerr.HTTPStatusCode(), output, "Error")
}

func WriteSuccessResponse(w http.ResponseWriter, output interface{}, locationName string) {
setCommonHeader(w, getRequestID())
err := WriteResponse(w, http.StatusOK, output, locationName)
if err != nil {
fmt.Println("write response: ", err)
}
_ = WriteResponse(w, http.StatusOK, output, locationName)
}

func setPutObjHeaders(w http.ResponseWriter, etag, cid string, delete bool) {
Expand Down
2 changes: 1 addition & 1 deletion s3/routers/routers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (routers *Routers) Register() http.Handler {
hs.Sign,
)

bucket := root.PathPrefix("/{bucket}").Subrouter()
bucket := root.PathPrefix("/{Bucket}").Subrouter()

// multipart object...
// CreateMultipart
Expand Down
2 changes: 1 addition & 1 deletion s3/server/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package server

const defaultServerAddress = "127.0.0.1:15001"
const defaultServerAddress = "127.0.0.1:6001"

type Option func(*Server)

Expand Down

0 comments on commit 2fb46c1

Please sign in to comment.