diff --git a/lib/http/cors_options.go b/lib/http/cors_options.go index 2d36e6d4..c1151556 100644 --- a/lib/http/cors_options.go +++ b/lib/http/cors_options.go @@ -4,6 +4,12 @@ package http +import ( + "net/http" + "strconv" + "strings" +) + // CORSOptions define optional options for server to allow other servers to // access its resources. type CORSOptions struct { @@ -36,6 +42,136 @@ type CORSOptions struct { // can be made using credentials. AllowCredentials bool - allowHeadersAll bool // flag to indicate wildcards on list. - allowOriginsAll bool // flag to indicate wildcards on list. + allowHeadersAll bool // flag to indicate wildcards on AllowHeaders. + allowOriginsAll bool // flag to indicate wildcards on AllowOrigins. +} + +// handle handle the CORS request. +// +// Reference: https://www.html5rocks.com/static/images/cors_server_flowchart.png +func (cors *CORSOptions) handle(res http.ResponseWriter, req *http.Request) { + var preflightOrigin = req.Header.Get(HeaderOrigin) + if len(preflightOrigin) == 0 { + return + } + + // Set the "Access-Control-Allow-Origin" header based on the request + // Origin and matched allowed origin. + // If one of the AllowOrigins contains wildcard "*", then allow all. + + if cors.allowOriginsAll { + res.Header().Set(HeaderACAllowOrigin, preflightOrigin) + } else { + var origin string + for _, origin = range cors.AllowOrigins { + if origin == corsWildcard { + res.Header().Set(HeaderACAllowOrigin, preflightOrigin) + break + } + if origin == preflightOrigin { + res.Header().Set(HeaderACAllowOrigin, preflightOrigin) + break + } + } + } + + // Set the "Access-Control-Allow-Method" header based on the request + // header "Access-Control-Request-Method", only allow HTTP method + // DELETE, GET, PATCH, POST, and PUT. + // If no "Access-Control-Request-Method", set the response header + // "Access-Control-Expose-Headers" based on predefined values. + + var preflightMethod = req.Header.Get(HeaderACRequestMethod) + if len(preflightMethod) == 0 { + if len(cors.exposeHeaders) > 0 { + res.Header().Set(HeaderACExposeHeaders, cors.exposeHeaders) + } + } else if preflightMethod == http.MethodDelete || + preflightMethod == http.MethodGet || + preflightMethod == http.MethodPatch || + preflightMethod == http.MethodPost || + preflightMethod == http.MethodPut { + res.Header().Set(HeaderACAllowMethod, preflightMethod) + } + + cors.handleRequestHeaders(res, req) + + if len(cors.maxAge) > 0 { + res.Header().Set(HeaderACMaxAge, cors.maxAge) + } + if cors.AllowCredentials { + res.Header().Set(HeaderACAllowCredentials, `true`) + } +} + +// handleRequestHeaders set the response header +// "Access-Control-Allow-Headers" based on the request header +// "Access-Control-Request-Headers". +// If [CORSOptions.AllowHeaders] is empty, no requested headers will be +// allowed. +// If [CORSOptions.AllowHeaders] contains wildcard "*", all requested +// headers are allowed. +func (cors *CORSOptions) handleRequestHeaders(res http.ResponseWriter, req *http.Request) { + var preflightHeaders = req.Header.Get(HeaderACRequestHeaders) + if len(preflightHeaders) == 0 { + return + } + + var ( + reqHeaders = strings.Split(preflightHeaders, `,`) + x int + ) + for x = 0; x < len(reqHeaders); x++ { + reqHeaders[x] = strings.ToLower(strings.TrimSpace(reqHeaders[x])) + } + + var ( + allowHeaders = make([]string, 0, len(reqHeaders)) + reqHeader string + allowHeader string + ) + for _, reqHeader = range reqHeaders { + if cors.allowHeadersAll { + allowHeaders = append(allowHeaders, reqHeader) + } else { + for _, allowHeader = range cors.AllowHeaders { + if reqHeader == allowHeader { + allowHeaders = append(allowHeaders, reqHeader) + break + } + } + } + } + if len(allowHeaders) == 0 { + return + } + + res.Header().Set(HeaderACAllowHeaders, strings.Join(allowHeaders, `,`)) +} + +func (cors *CORSOptions) init() { + var value string + + for _, value = range cors.AllowOrigins { + if value == corsWildcard { + cors.allowOriginsAll = true + break + } + } + + var x int + for x, value = range cors.AllowHeaders { + if value == corsWildcard { + cors.allowHeadersAll = true + } else { + cors.AllowHeaders[x] = strings.ToLower(cors.AllowHeaders[x]) + } + } + + if len(cors.ExposeHeaders) > 0 { + cors.exposeHeaders = strings.Join(cors.ExposeHeaders, `,`) + } + if cors.MaxAge > 0 { + cors.maxAge = strconv.Itoa(cors.MaxAge) + } } diff --git a/lib/http/server.go b/lib/http/server.go index b982b88c..e26fe082 100644 --- a/lib/http/server.go +++ b/lib/http/server.go @@ -277,11 +277,11 @@ func (srv *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { srv.handleDelete(res, req) case http.MethodGet: - srv.handleCORS(res, req) + srv.Options.CORS.handle(res, req) srv.handleGet(res, req) case http.MethodHead: - srv.handleCORS(res, req) + srv.Options.CORS.handle(res, req) srv.handleHead(res, req) case http.MethodOptions: @@ -291,7 +291,7 @@ func (srv *Server) ServeHTTP(res http.ResponseWriter, req *http.Request) { srv.handlePatch(res, req) case http.MethodPost: - srv.handleCORS(res, req) + srv.Options.CORS.handle(res, req) srv.handlePost(res, req) case http.MethodPut: @@ -385,98 +385,6 @@ func (srv *Server) getFSNode(reqPath string) (node *memfs.Node, isDir bool) { return node, false } -// handleCORS handle the CORS request. -// -// Reference: https://www.html5rocks.com/static/images/cors_server_flowchart.png -func (srv *Server) handleCORS(res http.ResponseWriter, req *http.Request) { - var preflightOrigin = req.Header.Get(HeaderOrigin) - if len(preflightOrigin) == 0 { - return - } - - // Set the "Access-Control-Allow-Origin" header based on the request - // Origin and matched allowed origin. - // If one of the AllowOrigins contains wildcard "*", then allow all. - - for _, origin := range srv.Options.CORS.AllowOrigins { - if origin == corsWildcard { - res.Header().Set(HeaderACAllowOrigin, preflightOrigin) - break - } - if origin == preflightOrigin { - res.Header().Set(HeaderACAllowOrigin, preflightOrigin) - break - } - } - - // Set the "Access-Control-Allow-Method" header based on the request - // header "Access-Control-Request-Method", only allow HTTP method - // DELETE, GET, PATCH, POST, and PUT. - // If no "Access-Control-Request-Method", set the response header - // "Access-Control-Expose-Headers" based on predefined values. - - var preflightMethod = req.Header.Get(HeaderACRequestMethod) - if len(preflightMethod) == 0 { - if len(srv.Options.CORS.exposeHeaders) > 0 { - res.Header().Set(HeaderACExposeHeaders, srv.Options.CORS.exposeHeaders) - } - } else if preflightMethod == http.MethodDelete || - preflightMethod == http.MethodGet || - preflightMethod == http.MethodPatch || - preflightMethod == http.MethodPost || - preflightMethod == http.MethodPut { - res.Header().Set(HeaderACAllowMethod, preflightMethod) - } - - srv.handleCORSRequestHeaders(res, req) - - if len(srv.Options.CORS.maxAge) > 0 { - res.Header().Set(HeaderACMaxAge, srv.Options.CORS.maxAge) - } - if srv.Options.CORS.AllowCredentials { - res.Header().Set(HeaderACAllowCredentials, "true") - } -} - -// handleCORSRequestHeaders set the response header -// "Access-Control-Allow-Headers" based on the request header -// "Access-Control-Request-Headers". -// If [CORSOptions.AllowHeaders] is empty, no requested headers will be -// allowed. -// If [CORSOptions.AllowHeaders] contains wildcard "*", all requested -// headers are allowed. -func (srv *Server) handleCORSRequestHeaders(res http.ResponseWriter, req *http.Request) { - preflightHeaders := req.Header.Get(HeaderACRequestHeaders) - if len(preflightHeaders) == 0 { - return - } - - reqHeaders := strings.Split(preflightHeaders, ",") - for x := 0; x < len(reqHeaders); x++ { - reqHeaders[x] = strings.ToLower(strings.TrimSpace(reqHeaders[x])) - } - - allowHeaders := make([]string, 0, len(reqHeaders)) - - for _, reqHeader := range reqHeaders { - for _, allowHeader := range srv.Options.CORS.AllowHeaders { - if allowHeader == corsWildcard { - allowHeaders = append(allowHeaders, reqHeader) - break - } - if reqHeader == allowHeader { - allowHeaders = append(allowHeaders, reqHeader) - break - } - } - } - if len(allowHeaders) == 0 { - return - } - - res.Header().Set(HeaderACAllowHeaders, strings.Join(allowHeaders, ",")) -} - // handleDelete handle the DELETE request by searching the registered route // and calling the endpoint. func (srv *Server) handleDelete(res http.ResponseWriter, req *http.Request) { @@ -748,7 +656,7 @@ func (srv *Server) handleOptions(res http.ResponseWriter, req *http.Request) { res.Header().Set(HeaderAllow, strings.Join(allows, ", ")) - srv.handleCORS(res, req) + srv.Options.CORS.handle(res, req) res.WriteHeader(http.StatusOK) } diff --git a/lib/http/serveroptions.go b/lib/http/serveroptions.go index 38da8531..5ff15fb6 100644 --- a/lib/http/serveroptions.go +++ b/lib/http/serveroptions.go @@ -8,8 +8,6 @@ import ( "io" "log" "net/http" - "strconv" - "strings" "git.sr.ht/~shulhan/pakakeh.go/lib/memfs" ) @@ -72,25 +70,5 @@ func (opts *ServerOptions) init() { } } - for x := 0; x < len(opts.CORS.AllowOrigins); x++ { - if opts.CORS.AllowOrigins[x] == corsWildcard { - opts.CORS.allowOriginsAll = true - break - } - } - - for x := 0; x < len(opts.CORS.AllowHeaders); x++ { - if opts.CORS.AllowHeaders[x] == corsWildcard { - opts.CORS.allowHeadersAll = true - } else { - opts.CORS.AllowHeaders[x] = strings.ToLower(opts.CORS.AllowHeaders[x]) - } - } - - if len(opts.CORS.ExposeHeaders) > 0 { - opts.CORS.exposeHeaders = strings.Join(opts.CORS.ExposeHeaders, ",") - } - if opts.CORS.MaxAge > 0 { - opts.CORS.maxAge = strconv.Itoa(opts.CORS.MaxAge) - } + opts.CORS.init() }