diff --git a/CHANGELOG.md b/CHANGELOG.md index 7888f074..10b362d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +## [0.15.0] - 2023-09-26 + +- Added a `Cache-Control` header to `/jwt/jwks.json` (`GetJWKSGET`) +- Added `ValidityInSeconds` to the return value of the overrideable `GetJWKS` function. + - This can be used to control the `Cache-Control` header mentioned above. + - It defaults to `60` or the value set in the cache-control header returned by the core + - This is optional (so you are not required to update your overrides). Returning undefined means that the header is not set. - Handle AWS Public URLs (ending with `.amazonaws.com`) separately while extracting TLDs for SameSite attribute. - Updates fiber adaptor package in the fiber example. diff --git a/recipe/jwt/api/implementation.go b/recipe/jwt/api/implementation.go index 0c93fa83..e6e96fc0 100644 --- a/recipe/jwt/api/implementation.go +++ b/recipe/jwt/api/implementation.go @@ -16,6 +16,7 @@ package api import ( + "fmt" "github.com/supertokens/supertokens-golang/recipe/jwt/jwtmodels" "github.com/supertokens/supertokens-golang/supertokens" ) @@ -26,8 +27,13 @@ func MakeAPIImplementation() jwtmodels.APIInterface { if err != nil { return jwtmodels.GetJWKSAPIResponse{}, err } + options.Res.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", resp.OK.ValidityInSeconds)) return jwtmodels.GetJWKSAPIResponse{ - OK: resp.OK, + OK: &struct { + Keys []jwtmodels.JsonWebKeys + }{ + Keys: resp.OK.Keys, + }, }, nil } return jwtmodels.APIInterface{ diff --git a/recipe/jwt/getJWKS_test.go b/recipe/jwt/getJWKS_test.go index c139ab9e..371ab3db 100644 --- a/recipe/jwt/getJWKS_test.go +++ b/recipe/jwt/getJWKS_test.go @@ -127,4 +127,84 @@ func TestDefaultGetJWKSWorksFine(t *testing.T) { result := *unittesting.HttpResponseToConsumableInformation(resp.Body) assert.NotNil(t, result) assert.Greater(t, len(result["keys"].([]interface{})), 0) + + cacheControl := resp.Header.Get("Cache-Control") + assert.Equal(t, cacheControl, "max-age=60, must-revalidate") +} + +func TestThatWeCanOverrideCacheControlThroughRecipeFunction(t *testing.T) { + configValue := supertokens.TypeInput{ + Supertokens: &supertokens.ConnectionInfo{ + ConnectionURI: "http://localhost:8080", + }, + AppInfo: supertokens.AppInfo{ + APIDomain: "api.supertokens.io", + AppName: "SuperTokens", + WebsiteDomain: "supertokens.io", + }, + RecipeList: []supertokens.Recipe{ + Init(&jwtmodels.TypeInput{ + Override: &jwtmodels.OverrideStruct{ + Functions: func(originalImplementation jwtmodels.RecipeInterface) jwtmodels.RecipeInterface { + originalGetJWKS := *originalImplementation.GetJWKS + + getJWKs := func(userContext supertokens.UserContext) (jwtmodels.GetJWKSResponse, error) { + result, err := originalGetJWKS(userContext) + + if err != nil { + return jwtmodels.GetJWKSResponse{}, err + } + + return jwtmodels.GetJWKSResponse{ + OK: &struct { + Keys []jwtmodels.JsonWebKeys + ValidityInSeconds int + }{Keys: result.OK.Keys, ValidityInSeconds: 1234}, + }, nil + } + + *originalImplementation.GetJWKS = getJWKs + + return originalImplementation + }, + }, + }), + }, + } + + BeforeEach() + unittesting.StartUpST("localhost", "8080") + defer AfterEach() + err := supertokens.Init(configValue) + if err != nil { + t.Error(err.Error()) + } + + q, err := supertokens.GetNewQuerierInstanceOrThrowError("") + if err != nil { + t.Error(err.Error()) + } + apiV, err := q.GetQuerierAPIVersion() + if err != nil { + t.Error(err.Error()) + } + + if unittesting.MaxVersion(apiV, "2.8") == "2.8" { + return + } + mux := http.NewServeMux() + testServer := httptest.NewServer(supertokens.Middleware(mux)) + defer testServer.Close() + + resp, err := http.Get(testServer.URL + "/auth/jwt/jwks.json") + if err != nil { + t.Error(err.Error()) + } + + result := *unittesting.HttpResponseToConsumableInformation(resp.Body) + assert.NotNil(t, result) + assert.Greater(t, len(result["keys"].([]interface{})), 0) + + cacheControl := resp.Header.Get("Cache-Control") + assert.Equal(t, cacheControl, "max-age=1234, must-revalidate") } diff --git a/recipe/jwt/jwtmodels/recipeInterface.go b/recipe/jwt/jwtmodels/recipeInterface.go index c3367397..d64d1587 100644 --- a/recipe/jwt/jwtmodels/recipeInterface.go +++ b/recipe/jwt/jwtmodels/recipeInterface.go @@ -31,6 +31,7 @@ type CreateJWTResponse struct { type GetJWKSResponse struct { OK *struct { - Keys []JsonWebKeys + Keys []JsonWebKeys + ValidityInSeconds int } } diff --git a/recipe/jwt/recipeimplementation.go b/recipe/jwt/recipeimplementation.go index 48728df5..4fd7dc21 100644 --- a/recipe/jwt/recipeimplementation.go +++ b/recipe/jwt/recipeimplementation.go @@ -18,8 +18,12 @@ package jwt import ( "github.com/supertokens/supertokens-golang/recipe/jwt/jwtmodels" "github.com/supertokens/supertokens-golang/supertokens" + "regexp" + "strconv" ) +var defaultJWKSMaxAge = 60 // This corresponds to the dynamicSigningKeyOverlapMS in the core + func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.TypeNormalisedInput, appInfo supertokens.NormalisedAppinfo) jwtmodels.RecipeInterface { createJWT := func(payload map[string]interface{}, validitySecondsPointer *uint64, useStaticSigningKey *bool, userContext supertokens.UserContext) (jwtmodels.CreateJWTResponse, error) { validitySeconds := config.JwtValiditySeconds @@ -61,7 +65,7 @@ func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.Type } } getJWKS := func(userContext supertokens.UserContext) (jwtmodels.GetJWKSResponse, error) { - response, err := querier.SendGetRequest("/.well-known/jwks.json", map[string]string{}) + response, headers, err := querier.SendGetRequestWithResponseHeaders("/.well-known/jwks.json", map[string]string{}) if err != nil { return jwtmodels.GetJWKSResponse{}, err } @@ -79,9 +83,29 @@ func makeRecipeImplementation(querier supertokens.Querier, config jwtmodels.Type }) } + validityInSeconds := defaultJWKSMaxAge + cacheControlHeader := headers.Get("Cache-Control") + + if cacheControlHeader != "" { + regex := regexp.MustCompile(`/,?\s*max-age=(\d+)(?:,|$)/`) + maxAgeHeader := regex.FindAllString(cacheControlHeader, -1) + + if maxAgeHeader != nil && len(maxAgeHeader) > 0 { + validityInSeconds, err = strconv.Atoi(maxAgeHeader[1]) + + if err != nil { + validityInSeconds = defaultJWKSMaxAge + } + } + } + return jwtmodels.GetJWKSResponse{ - OK: &struct{ Keys []jwtmodels.JsonWebKeys }{ - Keys: keys, + OK: &struct { + Keys []jwtmodels.JsonWebKeys + ValidityInSeconds int + }{ + Keys: keys, + ValidityInSeconds: validityInSeconds, }, }, nil } diff --git a/supertokens/constants.go b/supertokens/constants.go index adcff424..302a697a 100644 --- a/supertokens/constants.go +++ b/supertokens/constants.go @@ -21,7 +21,7 @@ const ( ) // VERSION current version of the lib -const VERSION = "0.14.0" +const VERSION = "0.15.0" var ( cdiSupported = []string{"3.0"} diff --git a/supertokens/querier.go b/supertokens/querier.go index 725195be..263442f8 100644 --- a/supertokens/querier.go +++ b/supertokens/querier.go @@ -56,7 +56,7 @@ func (q *Querier) GetQuerierAPIVersion() (string, error) { if querierAPIVersion != "" { return querierAPIVersion, nil } - response, err := q.sendRequestHelper(NormalisedURLPath{value: "/apiversion"}, func(url string) (*http.Response, error) { + response, _, err := q.sendRequestHelper(NormalisedURLPath{value: "/apiversion"}, func(url string) (*http.Response, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, err @@ -117,7 +117,7 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map if err != nil { return nil, err } - return q.sendRequestHelper(nP, func(url string) (*http.Response, error) { + resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) { if data == nil { data = map[string]interface{}{} } @@ -147,6 +147,7 @@ func (q *Querier) SendPostRequest(path string, data map[string]interface{}) (map client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) + return resp, err } func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, params map[string]string) (map[string]interface{}, error) { @@ -154,7 +155,7 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, pa if err != nil { return nil, err } - return q.sendRequestHelper(nP, func(url string) (*http.Response, error) { + resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) { jsonData, err := json.Marshal(data) if err != nil { return nil, err @@ -188,6 +189,7 @@ func (q *Querier) SendDeleteRequest(path string, data map[string]interface{}, pa client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) + return resp, err } func (q *Querier) SendGetRequest(path string, params map[string]string) (map[string]interface{}, error) { @@ -195,6 +197,43 @@ func (q *Querier) SendGetRequest(path string, params map[string]string) (map[str if err != nil { return nil, err } + resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + + query := req.URL.Query() + + for k, v := range params { + query.Add(k, v) + } + req.URL.RawQuery = query.Encode() + + apiVerion, querierAPIVersionError := q.GetQuerierAPIVersion() + if querierAPIVersionError != nil { + return nil, querierAPIVersionError + } + req.Header.Set("cdi-version", apiVerion) + if QuerierAPIKey != nil { + req.Header.Set("api-key", *QuerierAPIKey) + } + if nP.IsARecipePath() && q.RIDToCore != "" { + req.Header.Set("rid", q.RIDToCore) + } + + client := &http.Client{} + return client.Do(req) + }, len(QuerierHosts), nil) + return resp, err +} + +func (q *Querier) SendGetRequestWithResponseHeaders(path string, params map[string]string) (map[string]interface{}, http.Header, error) { + nP, err := NewNormalisedURLPath(path) + if err != nil { + return nil, nil, err + } + return q.sendRequestHelper(nP, func(url string) (*http.Response, error) { req, err := http.NewRequest("GET", url, nil) if err != nil { @@ -230,7 +269,7 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[ if err != nil { return nil, err } - return q.sendRequestHelper(nP, func(url string) (*http.Response, error) { + resp, _, err := q.sendRequestHelper(nP, func(url string) (*http.Response, error) { jsonData, err := json.Marshal(data) if err != nil { return nil, err @@ -257,6 +296,7 @@ func (q *Querier) SendPutRequest(path string, data map[string]interface{}) (map[ client := &http.Client{} return client.Do(req) }, len(QuerierHosts), nil) + return resp, err } type httpRequestFunction func(url string) (*http.Response, error) @@ -279,9 +319,9 @@ func GetAllCoreUrlsForPath(path string) []string { return result } -func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, error) { +func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequestFunction, numberOfTries int, retryInfoMap *map[string]int) (map[string]interface{}, http.Header, error) { if numberOfTries == 0 { - return nil, errors.New("no SuperTokens core available to query") + return nil, nil, errors.New("no SuperTokens core available to query") } querierHostLock.Lock() @@ -316,14 +356,14 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ if resp != nil { resp.Body.Close() } - return nil, err + return nil, nil, err } defer resp.Body.Close() body, readErr := ioutil.ReadAll(resp.Body) if readErr != nil { - return nil, readErr + return nil, nil, readErr } if resp.StatusCode != 200 { if resp.StatusCode == RateLimitStatusCode { @@ -341,17 +381,18 @@ func (q *Querier) sendRequestHelper(path NormalisedURLPath, httpRequest httpRequ } } - return nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body) + return nil, nil, fmt.Errorf("SuperTokens core threw an error for a request to path: '%s' with status code: %v and message: %s", path.GetAsStringDangerous(), resp.StatusCode, body) } + headers := resp.Header.Clone() finalResult := make(map[string]interface{}) jsonError := json.Unmarshal(body, &finalResult) if jsonError != nil { return map[string]interface{}{ "result": string(body), - }, nil + }, headers, nil } - return finalResult, nil + return finalResult, headers, nil } func ResetQuerierForTest() {