From 326d38b84a5df8cb0aa0aee9f056d5d1d0aa8bd7 Mon Sep 17 00:00:00 2001 From: Giannis Katsanos Date: Mon, 12 Feb 2024 11:43:40 +0200 Subject: [PATCH] feat: Cache requests for JWKS on JWT verification (#228) The jwt.Verify method needs to fetch the JSON Web Key Set from the API in order to verify the session JWT's validity. The jwt.Verify method is used in the http.WithHeaderAuthorization middleware, which means that in an HTTP server context, the method will executed for every request. We're adding a caching layer for the JWKS when we verify the session JWT. This way we can cache the JWKS response from the API for 1 hour. --- jwt/jwt.go | 107 ++++++++++++++++++++++++++++++++++++++++++++++-- jwt/jwt_test.go | 75 +++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 3 deletions(-) create mode 100644 jwt/jwt_test.go diff --git a/jwt/jwt.go b/jwt/jwt.go index f1fbb635..52c7e559 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "strings" + "sync" "time" "github.com/clerk/clerk-sdk-go/v2" @@ -103,10 +104,9 @@ func Verify(ctx context.Context, params *VerifyParams) (*clerk.SessionClaims, er return claims, nil } +// Retrieve the JSON web key for the provided id from the set. func getJWK(ctx context.Context, kid string) (*clerk.JSONWebKey, error) { - // TODO Avoid multiple requests by caching results for the same - // instance. - jwks, err := jwks.Get(ctx, &jwks.GetParams{}) + jwks, err := getJWKSWithCache(ctx) if err != nil { return nil, err } @@ -118,6 +118,44 @@ func getJWK(ctx context.Context, kid string) (*clerk.JSONWebKey, error) { return nil, fmt.Errorf("no jwk key found for kid %s", kid) } +// Returns the JSON web key set. Tries a cached value first, but if +// there's no value or the entry has expired, it will fetch the set +// from the API and cache the value. +func getJWKSWithCache(ctx context.Context) (*clerk.JSONWebKeySet, error) { + const cacheKey = "/v1/jwks" + var jwks *clerk.JSONWebKeySet + var err error + + // Try the cache first. Make sure we have a non-expired entry and + // that the value is a valid JWKS. + entry, ok := getCache().Get(cacheKey) + if ok && !entry.HasExpired() { + jwks, ok = entry.GetValue().(*clerk.JSONWebKeySet) + if !ok || jwks == nil || len(jwks.Keys) == 0 { + jwks, err = forceGetJWKS(ctx, cacheKey) + if err != nil { + return nil, err + } + } + } else { + jwks, err = forceGetJWKS(ctx, cacheKey) + if err != nil { + return nil, err + } + } + return jwks, err +} + +// Fetches the JSON web key set from the API and caches it. +func forceGetJWKS(ctx context.Context, cacheKey string) (*clerk.JSONWebKeySet, error) { + jwks, err := jwks.Get(ctx, &jwks.GetParams{}) + if err != nil { + return nil, err + } + getCache().Set(cacheKey, jwks, time.Now().UTC().Add(time.Hour)) + return jwks, nil +} + func isValidIssuer(iss string) bool { return strings.HasPrefix(iss, "https://clerk.") || strings.Contains(iss, ".clerk.accounts") @@ -154,3 +192,66 @@ func Decode(_ context.Context, params *DecodeParams) (*clerk.Claims, error) { Extra: extraClaims, }, nil } + +// Caching store. +type cache struct { + mu sync.RWMutex + entries map[string]*cacheEntry +} + +// Get returns the cache entry for the provided key, if one exists. +func (c *cache) Get(key string) (*cacheEntry, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + entry, ok := c.entries[key] + return entry, ok +} + +// Set adds a new entry with the provided value in the cache under +// the provided key. An expiration date will be set for the entry. +func (c *cache) Set(key string, value any, expiresAt time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + c.entries[key] = &cacheEntry{ + value: value, + expiresAt: expiresAt, + } +} + +// A cache entry has a value and an expiration date. +type cacheEntry struct { + value any + expiresAt time.Time +} + +// HasExpired returns true if the cache entry's expiration date +// has passed. +func (entry *cacheEntry) HasExpired() bool { + if entry == nil { + return true + } + return entry.expiresAt.Before(time.Now()) +} + +// GetValue returns the cache entry's value. +func (entry *cacheEntry) GetValue() any { + if entry == nil { + return nil + } + return entry.value +} + +var cacheInit sync.Once + +// A "singleton" cache for the package. +var defaultCache *cache + +// Lazy initialize and return the default cache singleton. +func getCache() *cache { + cacheInit.Do(func() { + defaultCache = &cache{ + entries: map[string]*cacheEntry{}, + } + }) + return defaultCache +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 00000000..a2673819 --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,75 @@ +package jwt + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/clerk/clerk-sdk-go/v2" + "github.com/clerk/clerk-sdk-go/v2/clerktest" + "github.com/stretchr/testify/require" +) + +func TestVerify_InvalidToken(t *testing.T) { + clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ + HTTPClient: &http.Client{ + Transport: &clerktest.RoundTripper{}, + }, + })) + + ctx := context.Background() + _, err := Verify(ctx, &VerifyParams{ + Token: "this-is-not-a-token", + }) + require.Error(t, err) +} + +func TestVerify_Cache(t *testing.T) { + ctx := context.Background() + totalRequests := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && r.URL.Path == "/v1/jwks" { + totalRequests++ + } + _, err := w.Write([]byte(`{ + "keys": [{ + "use": "sig", + "kty": "RSA", + "kid": "ins_123", + "alg": "RS256", + "n": "9m1LJW0dgEuK8SnN1Oy4LY8vaWABVS-hBTMA--_4LN1PZlMS5B2RPL85WkXYlHb0KXOSVrFKZLwYP-a9l3MFlW2YrPVAIvYfqPyqY5fmSEf-2qfrwosIhB2NSHyNRBQQ8-BX1RO9rIXIqYDKxGqktqMvYJmEGClmijbmFyQb2hpHD5PDbAB_DZvpZTEzWcQBL2ytHehILkYfg-ZZRyt7O8h5Gdy1v_TUlg8iMvchHlAkrIAmXNQigZmX_lne91tW8t4KMNJRfmUyLVCLbPnwxlmXXcice-0tmFw0OkCOteNWBeRNctJ3AIreGMzaJOJ2HeSUmJoX8iRKLLT3fsURLw", + "e": "AQAB" + }] +}`)) + require.NoError(t, err) + })) + defer ts.Close() + + clerk.SetBackend(clerk.NewBackend(&clerk.BackendConfig{ + HTTPClient: ts.Client(), + URL: clerk.String(ts.URL), + })) + + token := "eyJhbGciOiJSUzI1NiIsImNhdCI6ImNsX0I3ZDRQRDExMUFBQSIsImtpZCI6Imluc18yOWR6bUdmQ3JydzdSMDRaVFFZRDNKSTB5dkYiLCJ0eXAiOiJKV1QifQ.eyJhenAiOiJodHRwczovL2Rhc2hib2FyZC5wcm9kLmxjbGNsZXJrLmNvbSIsImV4cCI6MTcwNzMwMDMyMiwiaWF0IjoxNzA3MzAwMjYyLCJpc3MiOiJodHRwczovL2NsZXJrLnByb2QubGNsY2xlcmsuY29tIiwibmJmIjoxNzA3MzAwMjUyLCJvcmdzIjp7Im9yZ18ySUlwcVIxenFNeHJQQkhSazNzTDJOSnJUQkQiOiJvcmc6YWRtaW4iLCJvcmdfMllHMlNwd0IzWEJoNUo0ZXF5elFVb0dXMjVhIjoib3JnOmFkbWluIiwib3JnXzJhZzJ6bmgxWGFjTXI0dGRXYjZRbEZSQ2RuaiI6Im9yZzphZG1pbiIsIm9yZ18yYWlldHlXa3VFSEhaRmRSUTFvVjYzMnZWaFciOiJvcmc6YWRtaW4ifSwic2lkIjoic2Vzc18yYm84b2gyRnIyeTNueVoyRVZQYktBd2ZvaU0iLCJzdWIiOiJ1c2VyXzI5ZTBXTnp6M245V1Q5S001WlpJYTBVVjNDNyJ9.6GtQafMBYY3Ij3pKHOyBYKt76LoLeBC71QUY_ho3k5nb0FBSvV0upKFLPBvIXNuF7hH0FK2QqDcAmrhbzAI-2qF_Ynve8Xl4VZCRpbTuZI7uL-tVjCvMffEIH-BHtrZ-QcXhEmNFQNIPyZTu21242he7U6o4S8st_aLmukWQzj_4qir7o5_fmVhm7YkLa0gYG5SLjkr2czwem1VGFHEVEOrHjun-g6eMnDNMMMysIOkZFxeqiCnqpc4u1V7Z7jfoK0r_-Unp8mGGln5KWYMCQyp1l1SkGwugtxeWfSbE4eklKRmItGOdVftvTyG16kDGpzsb22AQGtg65Iygni4PHg" + // Providing a custom key will not trigger a request to fetch the + // key set. + _, _ = Verify(ctx, &VerifyParams{ + Token: token, + JWK: &clerk.JSONWebKey{}, + }) + require.Equal(t, 0, totalRequests) + + // Verify without providing a key. The method will trigger a request + // to fetch the key set. + _, _ = Verify(ctx, &VerifyParams{ + Token: token, + }) + require.Equal(t, 1, totalRequests) + // Verifying again won't trigger a request because the key set is + // cached. + _, _ = Verify(ctx, &VerifyParams{ + Token: token, + }) + require.Equal(t, 1, totalRequests) +}