Skip to content

Commit

Permalink
lib/http: move CORS initialization and handler to cors_options
Browse files Browse the repository at this point in the history
This is to decoupling between ServerOptions and Server handlers.
  • Loading branch information
shuLhan committed Mar 5, 2024
1 parent f27e172 commit 7651077
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 121 deletions.
140 changes: 138 additions & 2 deletions lib/http/cors_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

package http

import (
"net/http"
"strconv"
"strings"
)

// CORSOptions define optional options for server to allow other servers to
// access its resources.
type CORSOptions struct {
Expand Down Expand Up @@ -36,6 +42,136 @@ type CORSOptions struct {
// can be made using credentials.
AllowCredentials bool

allowHeadersAll bool // flag to indicate wildcards on list.
allowOriginsAll bool // flag to indicate wildcards on list.
allowHeadersAll bool // flag to indicate wildcards on AllowHeaders.
allowOriginsAll bool // flag to indicate wildcards on AllowOrigins.
}

// handle handle the CORS request.
//
// Reference: https://www.html5rocks.com/static/images/cors_server_flowchart.png
func (cors *CORSOptions) handle(res http.ResponseWriter, req *http.Request) {
var preflightOrigin = req.Header.Get(HeaderOrigin)
if len(preflightOrigin) == 0 {
return
}

// Set the "Access-Control-Allow-Origin" header based on the request
// Origin and matched allowed origin.
// If one of the AllowOrigins contains wildcard "*", then allow all.

if cors.allowOriginsAll {
res.Header().Set(HeaderACAllowOrigin, preflightOrigin)
} else {
var origin string
for _, origin = range cors.AllowOrigins {
if origin == corsWildcard {
res.Header().Set(HeaderACAllowOrigin, preflightOrigin)
break
}
if origin == preflightOrigin {
res.Header().Set(HeaderACAllowOrigin, preflightOrigin)
break
}
}
}

// Set the "Access-Control-Allow-Method" header based on the request
// header "Access-Control-Request-Method", only allow HTTP method
// DELETE, GET, PATCH, POST, and PUT.
// If no "Access-Control-Request-Method", set the response header
// "Access-Control-Expose-Headers" based on predefined values.

var preflightMethod = req.Header.Get(HeaderACRequestMethod)
if len(preflightMethod) == 0 {
if len(cors.exposeHeaders) > 0 {
res.Header().Set(HeaderACExposeHeaders, cors.exposeHeaders)
}
} else if preflightMethod == http.MethodDelete ||
preflightMethod == http.MethodGet ||
preflightMethod == http.MethodPatch ||
preflightMethod == http.MethodPost ||
preflightMethod == http.MethodPut {
res.Header().Set(HeaderACAllowMethod, preflightMethod)
}

cors.handleRequestHeaders(res, req)

if len(cors.maxAge) > 0 {
res.Header().Set(HeaderACMaxAge, cors.maxAge)
}
if cors.AllowCredentials {
res.Header().Set(HeaderACAllowCredentials, `true`)
}
}

// handleRequestHeaders set the response header
// "Access-Control-Allow-Headers" based on the request header
// "Access-Control-Request-Headers".
// If [CORSOptions.AllowHeaders] is empty, no requested headers will be
// allowed.
// If [CORSOptions.AllowHeaders] contains wildcard "*", all requested
// headers are allowed.
func (cors *CORSOptions) handleRequestHeaders(res http.ResponseWriter, req *http.Request) {
var preflightHeaders = req.Header.Get(HeaderACRequestHeaders)
if len(preflightHeaders) == 0 {
return
}

var (
reqHeaders = strings.Split(preflightHeaders, `,`)
x int
)
for x = 0; x < len(reqHeaders); x++ {
reqHeaders[x] = strings.ToLower(strings.TrimSpace(reqHeaders[x]))
}

var (
allowHeaders = make([]string, 0, len(reqHeaders))
reqHeader string
allowHeader string
)
for _, reqHeader = range reqHeaders {
if cors.allowHeadersAll {
allowHeaders = append(allowHeaders, reqHeader)
} else {
for _, allowHeader = range cors.AllowHeaders {
if reqHeader == allowHeader {
allowHeaders = append(allowHeaders, reqHeader)
break
}
}
}
}
if len(allowHeaders) == 0 {
return
}

res.Header().Set(HeaderACAllowHeaders, strings.Join(allowHeaders, `,`))
}

func (cors *CORSOptions) init() {
var value string

for _, value = range cors.AllowOrigins {
if value == corsWildcard {
cors.allowOriginsAll = true
break
}
}

var x int
for x, value = range cors.AllowHeaders {
if value == corsWildcard {
cors.allowHeadersAll = true
} else {
cors.AllowHeaders[x] = strings.ToLower(cors.AllowHeaders[x])
}
}

if len(cors.ExposeHeaders) > 0 {
cors.exposeHeaders = strings.Join(cors.ExposeHeaders, `,`)
}
if cors.MaxAge > 0 {
cors.maxAge = strconv.Itoa(cors.MaxAge)
}
}
100 changes: 4 additions & 96 deletions lib/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,11 @@ func (srv *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
srv.handleDelete(res, req)

case http.MethodGet:
srv.handleCORS(res, req)
srv.Options.CORS.handle(res, req)
srv.handleGet(res, req)

case http.MethodHead:
srv.handleCORS(res, req)
srv.Options.CORS.handle(res, req)
srv.handleHead(res, req)

case http.MethodOptions:
Expand All @@ -291,7 +291,7 @@ func (srv *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) {
srv.handlePatch(res, req)

case http.MethodPost:
srv.handleCORS(res, req)
srv.Options.CORS.handle(res, req)
srv.handlePost(res, req)

case http.MethodPut:
Expand Down Expand Up @@ -385,98 +385,6 @@ func (srv *Server) getFSNode(reqPath string) (node *memfs.Node, isDir bool) {
return node, false
}

// handleCORS handle the CORS request.
//
// Reference: https://www.html5rocks.com/static/images/cors_server_flowchart.png
func (srv *Server) handleCORS(res http.ResponseWriter, req *http.Request) {
var preflightOrigin = req.Header.Get(HeaderOrigin)
if len(preflightOrigin) == 0 {
return
}

// Set the "Access-Control-Allow-Origin" header based on the request
// Origin and matched allowed origin.
// If one of the AllowOrigins contains wildcard "*", then allow all.

for _, origin := range srv.Options.CORS.AllowOrigins {
if origin == corsWildcard {
res.Header().Set(HeaderACAllowOrigin, preflightOrigin)
break
}
if origin == preflightOrigin {
res.Header().Set(HeaderACAllowOrigin, preflightOrigin)
break
}
}

// Set the "Access-Control-Allow-Method" header based on the request
// header "Access-Control-Request-Method", only allow HTTP method
// DELETE, GET, PATCH, POST, and PUT.
// If no "Access-Control-Request-Method", set the response header
// "Access-Control-Expose-Headers" based on predefined values.

var preflightMethod = req.Header.Get(HeaderACRequestMethod)
if len(preflightMethod) == 0 {
if len(srv.Options.CORS.exposeHeaders) > 0 {
res.Header().Set(HeaderACExposeHeaders, srv.Options.CORS.exposeHeaders)
}
} else if preflightMethod == http.MethodDelete ||
preflightMethod == http.MethodGet ||
preflightMethod == http.MethodPatch ||
preflightMethod == http.MethodPost ||
preflightMethod == http.MethodPut {
res.Header().Set(HeaderACAllowMethod, preflightMethod)
}

srv.handleCORSRequestHeaders(res, req)

if len(srv.Options.CORS.maxAge) > 0 {
res.Header().Set(HeaderACMaxAge, srv.Options.CORS.maxAge)
}
if srv.Options.CORS.AllowCredentials {
res.Header().Set(HeaderACAllowCredentials, "true")
}
}

// handleCORSRequestHeaders set the response header
// "Access-Control-Allow-Headers" based on the request header
// "Access-Control-Request-Headers".
// If [CORSOptions.AllowHeaders] is empty, no requested headers will be
// allowed.
// If [CORSOptions.AllowHeaders] contains wildcard "*", all requested
// headers are allowed.
func (srv *Server) handleCORSRequestHeaders(res http.ResponseWriter, req *http.Request) {
preflightHeaders := req.Header.Get(HeaderACRequestHeaders)
if len(preflightHeaders) == 0 {
return
}

reqHeaders := strings.Split(preflightHeaders, ",")
for x := 0; x < len(reqHeaders); x++ {
reqHeaders[x] = strings.ToLower(strings.TrimSpace(reqHeaders[x]))
}

allowHeaders := make([]string, 0, len(reqHeaders))

for _, reqHeader := range reqHeaders {
for _, allowHeader := range srv.Options.CORS.AllowHeaders {
if allowHeader == corsWildcard {
allowHeaders = append(allowHeaders, reqHeader)
break
}
if reqHeader == allowHeader {
allowHeaders = append(allowHeaders, reqHeader)
break
}
}
}
if len(allowHeaders) == 0 {
return
}

res.Header().Set(HeaderACAllowHeaders, strings.Join(allowHeaders, ","))
}

// handleDelete handle the DELETE request by searching the registered route
// and calling the endpoint.
func (srv *Server) handleDelete(res http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -748,7 +656,7 @@ func (srv *Server) handleOptions(res http.ResponseWriter, req *http.Request) {

res.Header().Set(HeaderAllow, strings.Join(allows, ", "))

srv.handleCORS(res, req)
srv.Options.CORS.handle(res, req)

res.WriteHeader(http.StatusOK)
}
Expand Down
24 changes: 1 addition & 23 deletions lib/http/serveroptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"io"
"log"
"net/http"
"strconv"
"strings"

"git.sr.ht/~shulhan/pakakeh.go/lib/memfs"
)
Expand Down Expand Up @@ -72,25 +70,5 @@ func (opts *ServerOptions) init() {
}
}

for x := 0; x < len(opts.CORS.AllowOrigins); x++ {
if opts.CORS.AllowOrigins[x] == corsWildcard {
opts.CORS.allowOriginsAll = true
break
}
}

for x := 0; x < len(opts.CORS.AllowHeaders); x++ {
if opts.CORS.AllowHeaders[x] == corsWildcard {
opts.CORS.allowHeadersAll = true
} else {
opts.CORS.AllowHeaders[x] = strings.ToLower(opts.CORS.AllowHeaders[x])
}
}

if len(opts.CORS.ExposeHeaders) > 0 {
opts.CORS.exposeHeaders = strings.Join(opts.CORS.ExposeHeaders, ",")
}
if opts.CORS.MaxAge > 0 {
opts.CORS.maxAge = strconv.Itoa(opts.CORS.MaxAge)
}
opts.CORS.init()
}

0 comments on commit 7651077

Please sign in to comment.