diff --git a/token.go b/token.go index 3783dec..28d0172 100644 --- a/token.go +++ b/token.go @@ -58,17 +58,24 @@ func b64decode(data string) []byte { // Supports masked tokens. realToken comes from Token(r) and // sentToken is token sent unusual way. func VerifyToken(realToken, sentToken string) bool { - r := b64decode(realToken) + r, err := base64.StdEncoding.DecodeString(realToken) + if err != nil { + return false + } if len(r) == 2*tokenLength { r = unmaskToken(r) } - s := b64decode(sentToken) + s, err := base64.StdEncoding.DecodeString(sentToken) + if err != nil { + return false + } if len(s) == 2*tokenLength { s = unmaskToken(s) } - return subtle.ConstantTimeCompare(r, s) == 1 + return tokensEqual(r, s) } +// verifyToken expects the realToken to be unmasked and the sentToken to be masked func verifyToken(realToken, sentToken []byte) bool { realN := len(realToken) sentN := len(sentToken) @@ -77,15 +84,16 @@ func verifyToken(realToken, sentToken []byte) bool { // sentN == 2*tokenLength means the token is masked. if realN == tokenLength && sentN == 2*tokenLength { - return verifyMasked(realToken, sentToken) + return tokensEqual(realToken, unmaskToken(sentToken)) } return false } -// Verifies the masked token -func verifyMasked(realToken, sentToken []byte) bool { - sentPlain := unmaskToken(sentToken) - return subtle.ConstantTimeCompare(realToken, sentPlain) == 1 +// tokensEqual expects both tokens to be unmasked +func tokensEqual(realToken, sentToken []byte) bool { + return len(realToken) == tokenLength && + len(sentToken) == tokenLength && + subtle.ConstantTimeCompare(realToken, sentToken) == 1 } func checkForPRNG() { diff --git a/token_test.go b/token_test.go index 2a848db..66ddc18 100644 --- a/token_test.go +++ b/token_test.go @@ -2,6 +2,7 @@ package nosurf import ( "crypto/rand" + "encoding/base64" "testing" ) @@ -70,3 +71,60 @@ func TestVerifiesMaskedTokenCorrectly(t *testing.T) { t.Errorf("VerifyToken returned a false positive") } } + +func TestVerifyTokenBase64Invalid(t *testing.T) { + for _, pairs := range [][]string{ + {"foo", "bar"}, + {"foo", ""}, + {"", "bar"}, + {"", ""}, + } { + if VerifyToken(pairs[0], pairs[1]) { + t.Errorf("VerifyToken returned a false positive for: %v", pairs) + } + } +} + +func TestVerifyTokenUnMasked(t *testing.T) { + for i, tc := range []struct { + real string + send string + valid bool + }{ + { + real: "qwertyuiopasdfghjklzxcvbnm123456", + send: "qwertyuiopasdfghjklzxcvbnm123456", + valid: true, + }, + { + real: "qwertyuiopasdfghjklzxcvbnm123456", + send: "qwertyuiopasdfghjklzxcvbnm123456" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + valid: true, + }, + { + real: "qwertyuiopasdfghjklzxcvbnm123456" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + send: "qwertyuiopasdfghjklzxcvbnm123456" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + valid: true, + }, + { + real: "qwertyuiopasdfghjklzxcvbnm123456" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + send: "qwertyuiopasdfghjklzxcvbnm123456", + valid: true, + }, + } { + if VerifyToken( + base64.StdEncoding.EncodeToString([]byte(tc.real)), + base64.StdEncoding.EncodeToString([]byte(tc.send)), + ) != tc.valid { + t.Errorf("Verify token returned wrong result for case %d: %+v", i, tc) + } + } +}