Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates to use context from the request for params instead of passing… #41

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: go
go:
- 1.6.3
- 1.7
env:
- "PATH=/home/travis/gopath/bin:$PATH"
before_install:
Expand Down
33 changes: 19 additions & 14 deletions common.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
// +build go1.7

// Copyright 2015 Husobee Associates, LLC. All rights reserved.
// Use of this source code is governed by The MIT License, which
// can be found in the LICENSE file included.

package vestigo

import (
"context"
"net/http"
"net/url"
"strings"
)

// methods - a list of methods that are allowed
Expand All @@ -29,30 +30,34 @@ var AllowTrace = false

// Param - Get a url parameter by name
func Param(r *http.Request, name string) string {
return r.URL.Query().Get(":" + name)
// use the request context
if v, ok := r.Context().Value("vestigo_" + name).(string); ok {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as below.

return v
}
return ""
}

// ParamNames - Get a url parameter name list
func ParamNames(r *http.Request) []string {
var names []string
for k := range r.URL.Query() {
if strings.HasPrefix(k, ":") {
names = append(names, k)
}
if v, ok := r.Context().Value("vestigo_param_names").([]string); ok {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as below.

return v
}
return names
return []string{}
}

// AddParam - Add a vestigo-style parameter to the request -- useful for middleware
// Appends :name=value onto a blank request query string or appends &:name=value
// onto a non-blank request query string
func AddParam(r *http.Request, name, value string) {
q := url.QueryEscape(":"+name) + "=" + url.QueryEscape(value)
if r.URL.RawQuery != "" {
r.URL.RawQuery += "&" + q
} else {
r.URL.RawQuery += q
paramNames := []string{name}
if v, ok := r.Context().Value("vestigo_param_names").([]string); ok {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use a plain string type like that. See https://blog.golang.org/context#TOC_3.2.

for _, vv := range v {
paramNames = append(paramNames, vv)
}
}
ctx := context.WithValue(r.Context(), "vestigo_"+name, value)
ctx = context.WithValue(ctx, "vestigo_param_names", paramNames)
Copy link

@nhooyr nhooyr Nov 21, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a better solution would be to store a map[string]string in the context with context.WithValue and then use that to return the parameter, instead of using context.WithValue for every single parameter. This would also prevent multiple lists from being stored. Above, for every URL parameter, a new context with the slice paramNames and key "vestigo_param_names" is linked with the others. It would also let you have a single package level key to access the map and prevent collisions. See https://github.com/pressly/chi/blob/master/context.go#L50 for an example.

*r = *r.WithContext(ctx)
}

//validMethod - validate that the http method is valid.
Expand Down
64 changes: 64 additions & 0 deletions common_legacy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// +build !go1.7

// Copyright 2015 Husobee Associates, LLC. All rights reserved.
// Use of this source code is governed by The MIT License, which
// can be found in the LICENSE file included.

package vestigo

import (
"net/http"
"net/url"
"strings"
)

// methods - a list of methods that are allowed
var methods = map[string]bool{
http.MethodConnect: true,
http.MethodDelete: true,
http.MethodGet: true,
http.MethodHead: true,
http.MethodOptions: true,
http.MethodPatch: true,
http.MethodPost: true,
http.MethodPut: true,
http.MethodTrace: true,
}

// AllowTrace - Globally allow the TRACE method handling within vestigo url router. This
// generally not a good idea to have true in production settings, but excellent for testing.
var AllowTrace = false

// Param - Get a url parameter by name
func Param(r *http.Request, name string) string {
return r.URL.Query().Get(":" + name)
}

// ParamNames - Get a url parameter name list
func ParamNames(r *http.Request) []string {
var names []string
for k := range r.URL.Query() {
if strings.HasPrefix(k, ":") {
names = append(names, k)
}
}
return names
}

// AddParam - Add a vestigo-style parameter to the request -- useful for middleware
// Appends :name=value onto a blank request query string or appends &:name=value
// onto a non-blank request query string
func AddParam(r *http.Request, name, value string) {
q := url.QueryEscape(":"+name) + "=" + url.QueryEscape(value)
if r.URL.RawQuery != "" {
r.URL.RawQuery += "&" + q
} else {
r.URL.RawQuery += q
}
}

//validMethod - validate that the http method is valid.
func validMethod(method string) bool {
_, ok := methods[method]
return ok
}
37 changes: 37 additions & 0 deletions router_go17_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//+build go1.7
package vestigo

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetParamNotExists(t *testing.T) {
r, _ := http.NewRequest("GET", "/test?group=2", nil)
// shouldnt exist
val := Param(r, "location")
assert.Equal(t, "", val)
}

func TestParamNames(t *testing.T) {
r, _ := http.NewRequest("GET", "/test?group=2", nil)
AddParam(r, "user", "test")
AddParam(r, "location", "San Francisco, CA")
actual := ParamNames(r)

var foundLocation bool
var foundUser bool
for _, v := range actual {
if v == "user" {
foundUser = true
}
if v == "location" {
foundLocation = true
}
}

assert.Equal(t, foundUser, true)
assert.Equal(t, foundLocation, true)
}
72 changes: 72 additions & 0 deletions router_legacy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
//+build !go1.7

package vestigo

import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAddParamEncode(t *testing.T) {
r, _ := http.NewRequest("GET", "/test?:user=1", nil)
AddParam(r, "id", "2 2")
assert.Equal(t, r.URL.RawQuery, ":user=1&%3Aid=2+2")
}

func TestParamNames(t *testing.T) {
r, _ := http.NewRequest("GET", "/test?:user=1&group=2", nil)
AddParam(r, "location", "San Francisco, CA")
actual := ParamNames(r)

var foundLocation bool
var foundUser bool
for _, v := range actual {
if v == ":user" {
foundUser = true
}
if v == ":location" {
foundLocation = true
}
}

assert.Equal(t, foundUser, true)
assert.Equal(t, foundLocation, true)
}

func TestRouterParamGet(t *testing.T) {
r := NewRouter()
r.Add("GET", "/users/:uid", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "222", r.URL.Query().Get(":uid"))
assert.Equal(t, "222", Param(r, "uid"))
assert.Equal(t, "red", r.URL.Query().Get("color"))
assert.Equal(t, "burger", r.URL.Query().Get("food"))
})

req, _ := http.NewRequest("GET", "/users/222?color=red&food=burger", nil)
h := httptest.NewRecorder()
r.ServeHTTP(h, req)
}

func TestRouterParamPost(t *testing.T) {
r := NewRouter()
r.Add("POST", "/users/:uid", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "123", r.FormValue("id"))
assert.Equal(t, "123", r.Form.Get("id"))
assert.Equal(t, "222", r.URL.Query().Get(":uid"))
assert.Equal(t, "222", Param(r, "uid"))
assert.Equal(t, "red", r.URL.Query().Get("color"))
assert.Equal(t, "burger", r.URL.Query().Get("food"))
})

form := url.Values{}
form.Add("id", "123")
req, _ := http.NewRequest("POST", "/users/222?color=red&food=burger", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
h := httptest.NewRecorder()
r.ServeHTTP(h, req)
}
28 changes: 0 additions & 28 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,32 +316,6 @@ func TestRouterTwoParam(t *testing.T) {
}
}

func TestAddParamEncode(t *testing.T) {
r, _ := http.NewRequest("GET", "/test?:user=1", nil)
AddParam(r, "id", "2 2")
assert.Equal(t, r.URL.RawQuery, ":user=1&%3Aid=2+2")
}

func TestParamNames(t *testing.T) {
r, _ := http.NewRequest("GET", "/test?:user=1&group=2", nil)
AddParam(r, "location", "San Francisco, CA")
actual := ParamNames(r)

var foundLocation bool
var foundUser bool
for _, v := range actual {
if v == ":user" {
foundUser = true
}
if v == ":location" {
foundLocation = true
}
}

assert.Equal(t, foundUser, true)
assert.Equal(t, foundLocation, true)
}

/*
func TestRouterMatchAny(t *testing.T) {
r := NewRouter()
Expand Down Expand Up @@ -567,7 +541,6 @@ func TestRouterParamNames(t *testing.T) {
func TestRouterParamGet(t *testing.T) {
r := NewRouter()
r.Add("GET", "/users/:uid", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "222", r.URL.Query().Get(":uid"))
assert.Equal(t, "222", Param(r, "uid"))
assert.Equal(t, "red", r.URL.Query().Get("color"))
assert.Equal(t, "burger", r.URL.Query().Get("food"))
Expand All @@ -583,7 +556,6 @@ func TestRouterParamPost(t *testing.T) {
r.Add("POST", "/users/:uid", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "123", r.FormValue("id"))
assert.Equal(t, "123", r.Form.Get("id"))
assert.Equal(t, "222", r.URL.Query().Get(":uid"))
assert.Equal(t, "222", Param(r, "uid"))
assert.Equal(t, "red", r.URL.Query().Get("color"))
assert.Equal(t, "burger", r.URL.Query().Get("food"))
Expand Down