diff --git a/csrf.go b/csrf.go index b7c4733..361bfd1 100644 --- a/csrf.go +++ b/csrf.go @@ -6,9 +6,8 @@ package csrf import ( - "crypto/rand" "fmt" - r "math/rand" + "math/rand" "net/http" "reflect" "time" @@ -104,23 +103,30 @@ type Options struct { ErrorFunc func(w http.ResponseWriter) } +var src = rand.NewSource(time.Now().UnixNano()) + // randomBytes generates n random []byte. func randomBytes(n int) []byte { - const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - var bytes = make([]byte, n) - var randby bool - if num, err := rand.Read(bytes); num != n || err != nil { - r.Seed(time.Now().UnixNano()) - randby = true - } - for i, b := range bytes { - if randby { - bytes[i] = alphanum[r.Intn(len(alphanum))] - } else { - bytes[i] = alphanum[b%byte(len(alphanum))] + const ( + letterBytes = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + if idx := int(cache & letterIdxMask); idx < len(letterBytes) { + b[i] = letterBytes[idx] + i-- } + cache >>= letterIdxBits + remain-- } - return bytes + return b } var _ inject.FastInvoker = (*csrfInvoker)(nil)