Skip to content

Commit

Permalink
Merge pull request #368 from supertokens/jwks-cache
Browse files Browse the repository at this point in the history
chore: Update logic of get jwks to include cache control header handling
  • Loading branch information
rishabhpoddar authored Sep 26, 2023
2 parents f842d98 + 0a682de commit a311ca0
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 17 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]

## [0.15.0] - 2023-09-26

- Added a `Cache-Control` header to `/jwt/jwks.json` (`GetJWKSGET`)
- Added `ValidityInSeconds` to the return value of the overrideable `GetJWKS` function.
- This can be used to control the `Cache-Control` header mentioned above.
- It defaults to `60` or the value set in the cache-control header returned by the core
- This is optional (so you are not required to update your overrides). Returning undefined means that the header is not set.
- Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute.
- Updates fiber adaptor package in the fiber example.

Expand Down
8 changes: 7 additions & 1 deletion recipe/jwt/api/implementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package api

import (
"fmt"
"github.com/supertokens/supertokens-golang/recipe/jwt/jwtmodels"
"github.com/supertokens/supertokens-golang/supertokens"
)
Expand All @@ -26,8 +27,13 @@ func MakeAPIImplementation() jwtmodels.APIInterface {
if err != nil {
return jwtmodels.GetJWKSAPIResponse{}, err
}
options.Res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", resp.OK.ValidityInSeconds))
return jwtmodels.GetJWKSAPIResponse{
OK: resp.OK,
OK: &struct {
Keys []jwtmodels.JsonWebKeys
}{
Keys: resp.OK.Keys,
},
}, nil
}
return jwtmodels.APIInterface{
Expand Down
80 changes: 80 additions & 0 deletions recipe/jwt/getJWKS_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,84 @@ func TestDefaultGetJWKSWorksFine(t *testing.T) {
result := *unittesting.HttpResponseToConsumableInformation(resp.Body)
assert.NotNil(t, result)
assert.Greater(t, len(result["keys"].([]interface{})), 0)

cacheControl := resp.Header.Get("Cache-Control")
assert.Equal(t, cacheControl, "max-age=60, must-revalidate")
}

func TestThatWeCanOverrideCacheControlThroughRecipeFunction(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
},
AppInfo: supertokens.AppInfo{
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(&jwtmodels.TypeInput{
Override: &jwtmodels.OverrideStruct{
Functions: func(originalImplementation jwtmodels.RecipeInterface) jwtmodels.RecipeInterface {
originalGetJWKS := *originalImplementation.GetJWKS

getJWKs := func(userContext supertokens.UserContext) (jwtmodels.GetJWKSResponse, error) {
result, err := originalGetJWKS(userContext)

if err != nil {
return jwtmodels.GetJWKSResponse{}, err
}

return jwtmodels.GetJWKSResponse{
OK: &struct {
Keys []jwtmodels.JsonWebKeys
ValidityInSeconds int
}{Keys: result.OK.Keys, ValidityInSeconds: 1234},
}, nil
}

*originalImplementation.GetJWKS = getJWKs

return originalImplementation
},
},
}),
},
}

BeforeEach()
unittesting.StartUpST("localhost", "8080")
defer AfterEach()
err := supertokens.Init(configValue)
if err != nil {
t.Error(err.Error())
}

q, err := supertokens.GetNewQuerierInstanceOrThrowError("")
if err != nil {
t.Error(err.Error())
}
apiV, err := q.GetQuerierAPIVersion()
if err != nil {
t.Error(err.Error())
}

if unittesting.MaxVersion(apiV, "2.8") == "2.8" {
return
}
mux := http.NewServeMux()
testServer := httptest.NewServer(supertokens.Middleware(mux))
defer testServer.Close()

resp, err := http.Get(testServer.URL + "/auth/jwt/jwks.json")
if err != nil {
t.Error(err.Error())
}

result := *unittesting.HttpResponseToConsumableInformation(resp.Body)
assert.NotNil(t, result)
assert.Greater(t, len(result["keys"].([]interface{})), 0)

cacheControl := resp.Header.Get("Cache-Control")
assert.Equal(t, cacheControl, "max-age=1234, must-revalidate")
}
3 changes: 2 additions & 1 deletion recipe/jwt/jwtmodels/recipeInterface.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type CreateJWTResponse struct {

type GetJWKSResponse struct {
OK *struct {
Keys []JsonWebKeys
Keys []JsonWebKeys
ValidityInSeconds int
}
}
30 changes: 27 additions & 3 deletions recipe/jwt/recipeimplementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ package jwt
import (
"github.com/supertokens/supertokens-golang/recipe/jwt/jwtmodels"
"github.com/supertokens/supertokens-golang/supertokens"
"regexp"
"strconv"
)

var defaultJWKSMaxAge = 60 // This corresponds to the dynamicSigningKeyOverlapMS in the core

func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) jwtmodels.RecipeInterface {
createJWT := func(payload map[string]interface{}, validitySecondsPointer *uint64, useStaticSigningKey *bool, userContext supertokens.UserContext) (jwtmodels.CreateJWTResponse, error) {
validitySeconds := config.JwtValiditySeconds
Expand Down Expand Up @@ -61,7 +65,7 @@ func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.Type
}
}
getJWKS := func(userContext supertokens.UserContext) (jwtmodels.GetJWKSResponse, error) {
response, err := querier.SendGetRequest("/.well-known/jwks.json", map[string]string{})
response, headers, err := querier.SendGetRequestWithResponseHeaders("/.well-known/jwks.json", map[string]string{})
if err != nil {
return jwtmodels.GetJWKSResponse{}, err
}
Expand All @@ -79,9 +83,29 @@ func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.Type
})
}

validityInSeconds := defaultJWKSMaxAge
cacheControlHeader := headers.Get("Cache-Control")

if cacheControlHeader != "" {
regex := regexp.MustCompile(`/,?\s*max-age=(\d+)(?:,|$)/`)
maxAgeHeader := regex.FindAllString(cacheControlHeader, -1)

if maxAgeHeader != nil && len(maxAgeHeader) > 0 {
validityInSeconds, err = strconv.Atoi(maxAgeHeader[1])

if err != nil {
validityInSeconds = defaultJWKSMaxAge
}
}
}

return jwtmodels.GetJWKSResponse{
OK: &struct{ Keys []jwtmodels.JsonWebKeys }{
Keys: keys,
OK: &struct {
Keys []jwtmodels.JsonWebKeys
ValidityInSeconds int
}{
Keys: keys,
ValidityInSeconds: validityInSeconds,
},
}, nil
}
Expand Down
2 changes: 1 addition & 1 deletion supertokens/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const (
)

// VERSION current version of the lib
const VERSION = "0.14.0"
const VERSION = "0.15.0"

var (
cdiSupported = []string{"3.0"}
Expand Down
63 changes: 52 additions & 11 deletions supertokens/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (q *Querier) GetQuerierAPIVersion() (string, error) {
if querierAPIVersion != "" {
return querierAPIVersion, nil
}
response, err := q.sendRequestHelper(NormalisedURLPath{value: "/apiversion"}, func(url string) (*http.Response, error) {
response, _, err := q.sendRequestHelper(NormalisedURLPath{value: "/apiversion"}, func(url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
Expand Down Expand Up @@ -117,7 +117,7 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map
if err != nil {
return nil, err
}
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
if data == nil {
data = map[string]interface{}{}
}
Expand Down Expand Up @@ -147,14 +147,15 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map
client := &http.Client{}
return client.Do(req)
}, len(QuerierHosts), nil)
return resp, err
}

func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, params map[string]string) (map[string]interface{}, error) {
nP, err := NewNormalisedURLPath(path)
if err != nil {
return nil, err
}
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
Expand Down Expand Up @@ -188,13 +189,51 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, pa
client := &http.Client{}
return client.Do(req)
}, len(QuerierHosts), nil)
return resp, err
}

func (q *Querier) SendGetRequest(path string, params map[string]string) (map[string]interface{}, error) {
nP, err := NewNormalisedURLPath(path)
if err != nil {
return nil, err
}
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}

query := req.URL.Query()

for k, v := range params {
query.Add(k, v)
}
req.URL.RawQuery = query.Encode()

apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion()
if querierAPIVersionError != nil {
return nil, querierAPIVersionError
}
req.Header.Set("cdi-version", apiVerion)
if QuerierAPIKey != nil {
req.Header.Set("api-key", *QuerierAPIKey)
}
if nP.IsARecipePath() && q.RIDToCore != "" {
req.Header.Set("rid", q.RIDToCore)
}

client := &http.Client{}
return client.Do(req)
}, len(QuerierHosts), nil)
return resp, err
}

func (q *Querier) SendGetRequestWithResponseHeaders(path string, params map[string]string) (map[string]interface{}, http.Header, error) {
nP, err := NewNormalisedURLPath(path)
if err != nil {
return nil, nil, err
}

return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
Expand Down Expand Up @@ -230,7 +269,7 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[
if err != nil {
return nil, err
}
return q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) {
jsonData, err := json.Marshal(data)
if err != nil {
return nil, err
Expand All @@ -257,6 +296,7 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[
client := &http.Client{}
return client.Do(req)
}, len(QuerierHosts), nil)
return resp, err
}

type httpRequestFunction func(url string) (*http.Response, error)
Expand All @@ -279,9 +319,9 @@ func GetAllCoreUrlsForPath(path string) []string {
return result
}

func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, error) {
func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, http.Header, error) {
if numberOfTries == 0 {
return nil, errors.New("no SuperTokens core available to query")
return nil, nil, errors.New("no SuperTokens core available to query")
}

querierHostLock.Lock()
Expand Down Expand Up @@ -316,14 +356,14 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ
if resp != nil {
resp.Body.Close()
}
return nil, err
return nil, nil, err
}

defer resp.Body.Close()

body, readErr := ioutil.ReadAll(resp.Body)
if readErr != nil {
return nil, readErr
return nil, nil, readErr
}
if resp.StatusCode != 200 {
if resp.StatusCode == RateLimitStatusCode {
Expand All @@ -341,17 +381,18 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ
}
}

return nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body)
return nil, nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body)
}

headers := resp.Header.Clone()
finalResult := make(map[string]interface{})
jsonError := json.Unmarshal(body, &finalResult)
if jsonError != nil {
return map[string]interface{}{
"result": string(body),
}, nil
}, headers, nil
}
return finalResult, nil
return finalResult, headers, nil
}

func ResetQuerierForTest() {
Expand Down

0 comments on commit a311ca0

Please sign in to comment.