diff --git a/.travis.yml b/.travis.yml index 8bc89b8..e696075 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: go go: - - 1.6.3 + - 1.7 env: - "PATH=/home/travis/gopath/bin:$PATH" before_install: diff --git a/common.go b/common.go index e4c729d..ae16ffc 100644 --- a/common.go +++ b/common.go @@ -1,3 +1,5 @@ +// +build go1.7 + // Copyright 2015 Husobee Associates, LLC. All rights reserved. // Use of this source code is governed by The MIT License, which // can be found in the LICENSE file included. @@ -5,9 +7,8 @@ package vestigo import ( + "context" "net/http" - "net/url" - "strings" ) // methods - a list of methods that are allowed @@ -29,30 +30,34 @@ var AllowTrace = false // Param - Get a url parameter by name func Param(r *http.Request, name string) string { - return r.URL.Query().Get(":" + name) + // use the request context + if v, ok := r.Context().Value("vestigo_" + name).(string); ok { + return v + } + return "" } // ParamNames - Get a url parameter name list func ParamNames(r *http.Request) []string { - var names []string - for k := range r.URL.Query() { - if strings.HasPrefix(k, ":") { - names = append(names, k) - } + if v, ok := r.Context().Value("vestigo_param_names").([]string); ok { + return v } - return names + return []string{} } // AddParam - Add a vestigo-style parameter to the request -- useful for middleware // Appends :name=value onto a blank request query string or appends &:name=value // onto a non-blank request query string func AddParam(r *http.Request, name, value string) { - q := url.QueryEscape(":"+name) + "=" + url.QueryEscape(value) - if r.URL.RawQuery != "" { - r.URL.RawQuery += "&" + q - } else { - r.URL.RawQuery += q + paramNames := []string{name} + if v, ok := r.Context().Value("vestigo_param_names").([]string); ok { + for _, vv := range v { + paramNames = append(paramNames, vv) + } } + ctx := context.WithValue(r.Context(), "vestigo_"+name, value) + ctx = context.WithValue(ctx, "vestigo_param_names", paramNames) + *r = *r.WithContext(ctx) } //validMethod - validate that the http method is valid. diff --git a/common_legacy.go b/common_legacy.go new file mode 100644 index 0000000..ff16dc9 --- /dev/null +++ b/common_legacy.go @@ -0,0 +1,64 @@ +// +build !go1.7 + +// Copyright 2015 Husobee Associates, LLC. All rights reserved. +// Use of this source code is governed by The MIT License, which +// can be found in the LICENSE file included. + +package vestigo + +import ( + "net/http" + "net/url" + "strings" +) + +// methods - a list of methods that are allowed +var methods = map[string]bool{ + http.MethodConnect: true, + http.MethodDelete: true, + http.MethodGet: true, + http.MethodHead: true, + http.MethodOptions: true, + http.MethodPatch: true, + http.MethodPost: true, + http.MethodPut: true, + http.MethodTrace: true, +} + +// AllowTrace - Globally allow the TRACE method handling within vestigo url router. This +// generally not a good idea to have true in production settings, but excellent for testing. +var AllowTrace = false + +// Param - Get a url parameter by name +func Param(r *http.Request, name string) string { + return r.URL.Query().Get(":" + name) +} + +// ParamNames - Get a url parameter name list +func ParamNames(r *http.Request) []string { + var names []string + for k := range r.URL.Query() { + if strings.HasPrefix(k, ":") { + names = append(names, k) + } + } + return names +} + +// AddParam - Add a vestigo-style parameter to the request -- useful for middleware +// Appends :name=value onto a blank request query string or appends &:name=value +// onto a non-blank request query string +func AddParam(r *http.Request, name, value string) { + q := url.QueryEscape(":"+name) + "=" + url.QueryEscape(value) + if r.URL.RawQuery != "" { + r.URL.RawQuery += "&" + q + } else { + r.URL.RawQuery += q + } +} + +//validMethod - validate that the http method is valid. +func validMethod(method string) bool { + _, ok := methods[method] + return ok +} diff --git a/router_go17_test.go b/router_go17_test.go new file mode 100644 index 0000000..2bde4be --- /dev/null +++ b/router_go17_test.go @@ -0,0 +1,37 @@ +//+build go1.7 +package vestigo + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetParamNotExists(t *testing.T) { + r, _ := http.NewRequest("GET", "/test?group=2", nil) + // shouldnt exist + val := Param(r, "location") + assert.Equal(t, "", val) +} + +func TestParamNames(t *testing.T) { + r, _ := http.NewRequest("GET", "/test?group=2", nil) + AddParam(r, "user", "test") + AddParam(r, "location", "San Francisco, CA") + actual := ParamNames(r) + + var foundLocation bool + var foundUser bool + for _, v := range actual { + if v == "user" { + foundUser = true + } + if v == "location" { + foundLocation = true + } + } + + assert.Equal(t, foundUser, true) + assert.Equal(t, foundLocation, true) +} diff --git a/router_legacy_test.go b/router_legacy_test.go new file mode 100644 index 0000000..ed9ab22 --- /dev/null +++ b/router_legacy_test.go @@ -0,0 +1,72 @@ +//+build !go1.7 + +package vestigo + +import ( + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddParamEncode(t *testing.T) { + r, _ := http.NewRequest("GET", "/test?:user=1", nil) + AddParam(r, "id", "2 2") + assert.Equal(t, r.URL.RawQuery, ":user=1&%3Aid=2+2") +} + +func TestParamNames(t *testing.T) { + r, _ := http.NewRequest("GET", "/test?:user=1&group=2", nil) + AddParam(r, "location", "San Francisco, CA") + actual := ParamNames(r) + + var foundLocation bool + var foundUser bool + for _, v := range actual { + if v == ":user" { + foundUser = true + } + if v == ":location" { + foundLocation = true + } + } + + assert.Equal(t, foundUser, true) + assert.Equal(t, foundLocation, true) +} + +func TestRouterParamGet(t *testing.T) { + r := NewRouter() + r.Add("GET", "/users/:uid", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "222", r.URL.Query().Get(":uid")) + assert.Equal(t, "222", Param(r, "uid")) + assert.Equal(t, "red", r.URL.Query().Get("color")) + assert.Equal(t, "burger", r.URL.Query().Get("food")) + }) + + req, _ := http.NewRequest("GET", "/users/222?color=red&food=burger", nil) + h := httptest.NewRecorder() + r.ServeHTTP(h, req) +} + +func TestRouterParamPost(t *testing.T) { + r := NewRouter() + r.Add("POST", "/users/:uid", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "123", r.FormValue("id")) + assert.Equal(t, "123", r.Form.Get("id")) + assert.Equal(t, "222", r.URL.Query().Get(":uid")) + assert.Equal(t, "222", Param(r, "uid")) + assert.Equal(t, "red", r.URL.Query().Get("color")) + assert.Equal(t, "burger", r.URL.Query().Get("food")) + }) + + form := url.Values{} + form.Add("id", "123") + req, _ := http.NewRequest("POST", "/users/222?color=red&food=burger", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + h := httptest.NewRecorder() + r.ServeHTTP(h, req) +} diff --git a/router_test.go b/router_test.go index f652c85..5281e63 100644 --- a/router_test.go +++ b/router_test.go @@ -316,32 +316,6 @@ func TestRouterTwoParam(t *testing.T) { } } -func TestAddParamEncode(t *testing.T) { - r, _ := http.NewRequest("GET", "/test?:user=1", nil) - AddParam(r, "id", "2 2") - assert.Equal(t, r.URL.RawQuery, ":user=1&%3Aid=2+2") -} - -func TestParamNames(t *testing.T) { - r, _ := http.NewRequest("GET", "/test?:user=1&group=2", nil) - AddParam(r, "location", "San Francisco, CA") - actual := ParamNames(r) - - var foundLocation bool - var foundUser bool - for _, v := range actual { - if v == ":user" { - foundUser = true - } - if v == ":location" { - foundLocation = true - } - } - - assert.Equal(t, foundUser, true) - assert.Equal(t, foundLocation, true) -} - /* func TestRouterMatchAny(t *testing.T) { r := NewRouter() @@ -567,7 +541,6 @@ func TestRouterParamNames(t *testing.T) { func TestRouterParamGet(t *testing.T) { r := NewRouter() r.Add("GET", "/users/:uid", func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "222", r.URL.Query().Get(":uid")) assert.Equal(t, "222", Param(r, "uid")) assert.Equal(t, "red", r.URL.Query().Get("color")) assert.Equal(t, "burger", r.URL.Query().Get("food")) @@ -583,7 +556,6 @@ func TestRouterParamPost(t *testing.T) { r.Add("POST", "/users/:uid", func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "123", r.FormValue("id")) assert.Equal(t, "123", r.Form.Get("id")) - assert.Equal(t, "222", r.URL.Query().Get(":uid")) assert.Equal(t, "222", Param(r, "uid")) assert.Equal(t, "red", r.URL.Query().Get("color")) assert.Equal(t, "burger", r.URL.Query().Get("food"))