Skip to content

Commit

Permalink
fix: jwks cache and test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sattvikc committed Jul 26, 2024
1 parent a651892 commit 29eef08
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 46 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,6 @@ releasePassword
.vscode/
.idea/
/test_report

build-errors.log
main
16 changes: 8 additions & 8 deletions recipe/emailpassword/network_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ func TestNetworkInterceptorDuringSignIn(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) *http.Request {
NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) (*http.Request, error) {
isNetworkIntercepted = true
return request
return request, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -105,11 +105,11 @@ func TestNetworkInterceptorIncorrectCoreURL(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) *http.Request {
NetworkInterceptor: func(request *http.Request, context supertokens.UserContext) (*http.Request, error) {
isNetworkIntercepted = true
newRequest := request
newRequest.URL.Path = "/public/recipe/incorrect/path"
return newRequest
return newRequest, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -149,12 +149,12 @@ func TestNetworkInterceptorIncorrectQueryParams(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) (*http.Request, error) {
isNetworkIntercepted = true
newRequest := r
q := url.Values{}
newRequest.URL.RawQuery = q.Encode()
return newRequest
return newRequest, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -191,12 +191,12 @@ func TestNetworkInterceptorRequestBody(t *testing.T) {
configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, context supertokens.UserContext) (*http.Request, error) {
isNetworkIntercepted = true
newBody := bytes.NewReader([]byte(`{"newKey": "newValue"}`))
req, _ := http.NewRequest(r.Method, r.URL.String(), newBody)
req.Header = r.Header
return req
return req, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down
28 changes: 14 additions & 14 deletions recipe/emailpassword/querier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ func TestCachingWorks(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -79,9 +79,9 @@ func TestNoCachingIfDisabledByUser(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
DisableCoreCallCache: true,
},
Expand Down Expand Up @@ -124,9 +124,9 @@ func TestNoCachingIfHeadersAreDifferent(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -176,9 +176,9 @@ func TestCachingGetsClearWhenQueryWithoutUserContext(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -223,9 +223,9 @@ func TestCachingDoesNotGetClearWithNonGetIfKeepAlive(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -295,9 +295,9 @@ func TestCachingGetsClearWithNonGetIfKeepAliveIsFalse(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down Expand Up @@ -367,9 +367,9 @@ func TestCachingGetsClearWithNonGetIfKeepAliveIsNotSet(t *testing.T) {
config := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
ConnectionURI: "http://localhost:8080",
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) *http.Request {
NetworkInterceptor: func(r *http.Request, uc supertokens.UserContext) (*http.Request, error) {
calledCore = true
return r
return r, nil
},
},
AppInfo: supertokens.AppInfo{
Expand Down
3 changes: 3 additions & 0 deletions recipe/passwordless/passwordless_email_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,9 @@ func TestThatMagicLinkUsesRightValueFromOriginFunction(t *testing.T) {
APIDomain: "api.supertokens.io",
AppName: "SuperTokens",
GetOrigin: func(request *http.Request, userContext supertokens.UserContext) (string, error) {
if request == nil {
return "https://supertokens.com", nil
}
// read request body
decoder := json.NewDecoder(request.Body)
var requestBody map[string]interface{}
Expand Down
1 change: 0 additions & 1 deletion recipe/session/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ const (
CookieSameSite_STRICT = "strict"
)

var JWKCacheMaxAgeInMs int64 = 60000
var JWKRefreshRateLimit = 500
var protectedProps = []string{
"sub",
Expand Down
8 changes: 7 additions & 1 deletion recipe/session/recipeImplementation.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ var mutex sync.RWMutex
func getJWKSFromCacheIfPresent() *sessmodels.GetJWKSResult {
mutex.RLock()
defer mutex.RUnlock()

sessionInstance, err := getRecipeInstanceOrThrowError()
if err != nil {
return nil
}

if jwksCache != nil {
// This means that we have valid JWKs for the given core path
// We check if we need to refresh before returning
Expand All @@ -48,7 +54,7 @@ func getJWKSFromCacheIfPresent() *sessmodels.GetJWKSResult {
// Note that this also means that the SDK will not try to query any other Core (if there are multiple)
// if it has a valid cache entry from one of the core URLs. It will only attempt to fetch
// from the cores again after the entry in the cache is expired
if (currentTime - jwksCache.LastFetched) < JWKCacheMaxAgeInMs {
if (currentTime - jwksCache.LastFetched) < int64(sessionInstance.Config.JWKSRefreshIntervalSec*1000) {
if supertokens.IsRunningInTestMode() {
if len(returnedFromCache) == cap(returnedFromCache) { // need to clear the channel if full because it's not being consumed in the test
close(returnedFromCache)
Expand Down
2 changes: 1 addition & 1 deletion recipe/session/sessionFunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func getSessionHelper(config sessmodels.TypeNormalisedInput, querier supertokens

// We check if the token was created since the last time we refreshed the keys from the core
// Since we do not know the exact timing of the last refresh, we check against the max age
if timeCreated <= (GetCurrTimeInMS() - uint64(JWKCacheMaxAgeInMs)) {
if timeCreated <= (GetCurrTimeInMS() - config.JWKSRefreshIntervalSec*1000) {
return sessmodels.GetSessionResponse{}, err
}
} else {
Expand Down
42 changes: 21 additions & 21 deletions recipe/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1496,10 +1496,9 @@ This test verifies that the SDK calls the well known API properly in the normal
*/
func TestThatJWKSIsFetchedAsExpected(t *testing.T) {
originalRefreshlimit := JWKRefreshRateLimit
originalCacheAge := JWKCacheMaxAgeInMs

JWKRefreshRateLimit = 100
JWKCacheMaxAgeInMs = 2000
var JWKCacheMaxAgeInSec uint64 = 2

lastLineBeforeTest := unittesting.GetInfoLogData(t, "").LastLine

Expand All @@ -1513,7 +1512,9 @@ func TestThatJWKSIsFetchedAsExpected(t *testing.T) {
APIDomain: "api.supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(nil),
Init(&sessmodels.TypeInput{
JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec,
}),
},
}
BeforeEach()
Expand Down Expand Up @@ -1548,7 +1549,7 @@ func TestThatJWKSIsFetchedAsExpected(t *testing.T) {
t.Error(err.Error())
}

time.Sleep(time.Duration(JWKCacheMaxAgeInMs) * time.Millisecond)
time.Sleep(time.Duration(JWKCacheMaxAgeInSec) * time.Second)

logInfoAfterWaiting := unittesting.GetInfoLogData(t, lastLineBeforeTest)
wellKnownCallLogs = []string{}
Expand All @@ -1562,7 +1563,6 @@ func TestThatJWKSIsFetchedAsExpected(t *testing.T) {
assert.Equal(t, len(wellKnownCallLogs), 1)

JWKRefreshRateLimit = originalRefreshlimit
JWKCacheMaxAgeInMs = originalCacheAge
}

/*
Expand All @@ -1578,10 +1578,9 @@ cache expired and the keys need to be refetched.
*/
func TestThatJWKSResultIsRefreshedProperly(t *testing.T) {
originalRefreshlimit := JWKRefreshRateLimit
originalCacheAge := JWKCacheMaxAgeInMs

JWKRefreshRateLimit = 100
JWKCacheMaxAgeInMs = 2000
JWKCacheMaxAgeInSec := uint64(2)

configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
Expand All @@ -1593,7 +1592,9 @@ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) {
APIDomain: "api.supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(nil),
Init(&sessmodels.TypeInput{
JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec,
}),
},
}
BeforeEach()
Expand Down Expand Up @@ -1633,7 +1634,6 @@ func TestThatJWKSResultIsRefreshedProperly(t *testing.T) {

assert.True(t, len(newKeys) != 0)
JWKRefreshRateLimit = originalRefreshlimit
JWKCacheMaxAgeInMs = originalCacheAge
}

/*
Expand Down Expand Up @@ -1794,10 +1794,9 @@ This test verifies the behaviour of the JWKS cache maintained by the SDK
*/
func TestJWKSCacheLogic(t *testing.T) {
originalRefreshlimit := JWKRefreshRateLimit
originalCacheAge := JWKCacheMaxAgeInMs

JWKRefreshRateLimit = 100
JWKCacheMaxAgeInMs = 2000
var JWKCacheMaxAgeInSec uint64 = 2

configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
Expand All @@ -1809,7 +1808,9 @@ func TestJWKSCacheLogic(t *testing.T) {
APIDomain: "api.supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(nil),
Init(&sessmodels.TypeInput{
JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec,
}),
},
}
BeforeEach()
Expand Down Expand Up @@ -1849,7 +1850,6 @@ func TestJWKSCacheLogic(t *testing.T) {
assert.NotNil(t, jwksCache)

JWKRefreshRateLimit = originalRefreshlimit
JWKCacheMaxAgeInMs = originalCacheAge
}

/*
Expand Down Expand Up @@ -1940,10 +1940,9 @@ This test ensures that the SDK's caching logic for fetching JWKs works fine
*/
func TestThatJWKSReturnsFromCacheCorrectly(t *testing.T) {
originalRefreshlimit := JWKRefreshRateLimit
originalCacheAge := JWKCacheMaxAgeInMs

JWKRefreshRateLimit = 100
JWKCacheMaxAgeInMs = 2000
var JWKCacheMaxAgeInSec uint64 = 2

configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
Expand All @@ -1955,7 +1954,9 @@ func TestThatJWKSReturnsFromCacheCorrectly(t *testing.T) {
APIDomain: "api.supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(nil),
Init(&sessmodels.TypeInput{
JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec,
}),
},
}
BeforeEach()
Expand Down Expand Up @@ -2002,7 +2003,6 @@ func TestThatJWKSReturnsFromCacheCorrectly(t *testing.T) {
assert.Equal(t, <-returnedFromCache, false)

JWKRefreshRateLimit = originalRefreshlimit
JWKCacheMaxAgeInMs = originalCacheAge
}

/*
Expand Down Expand Up @@ -2205,10 +2205,9 @@ func TestSessionVerificationOfJWTBasedOnSessionPayloadWithCheckDatabase(t *testi

func TestThatLockingForJWKSCacheWorksFine(t *testing.T) {
originalRefreshlimit := JWKRefreshRateLimit
originalCacheAge := JWKCacheMaxAgeInMs

JWKRefreshRateLimit = 100
JWKCacheMaxAgeInMs = 2000
var JWKCacheMaxAgeInSec uint64 = 2

configValue := supertokens.TypeInput{
Supertokens: &supertokens.ConnectionInfo{
Expand All @@ -2220,7 +2219,9 @@ func TestThatLockingForJWKSCacheWorksFine(t *testing.T) {
WebsiteDomain: "supertokens.io",
},
RecipeList: []supertokens.Recipe{
Init(nil),
Init(&sessmodels.TypeInput{
JWKSRefreshIntervalSec: &JWKCacheMaxAgeInSec,
}),
},
}
BeforeEach()
Expand Down Expand Up @@ -2295,7 +2296,6 @@ func TestThatLockingForJWKSCacheWorksFine(t *testing.T) {
assert.Equal(t, notReturnFromCacheCount, 5)

JWKRefreshRateLimit = originalRefreshlimit
JWKCacheMaxAgeInMs = originalCacheAge
}

func TestThatGetSessionThrowsWIthDynamicKeysIfSessionWasCreatedWithStaticKeys(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions recipe/session/sessmodels/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ type TypeInput struct {
GetTokenTransferMethod func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) TokenTransferMethod
ExposeAccessTokenToFrontendInCookieBasedAuth bool
UseDynamicAccessTokenSigningKey *bool
JWKSRefreshIntervalSec *uint64
}

type OverrideStruct struct {
Expand Down Expand Up @@ -141,6 +142,7 @@ type TypeNormalisedInput struct {
GetTokenTransferMethod func(req *http.Request, forCreateNewSession bool, userContext supertokens.UserContext) TokenTransferMethod
ExposeAccessTokenToFrontendInCookieBasedAuth bool
UseDynamicAccessTokenSigningKey bool
JWKSRefreshIntervalSec uint64
}

type AntiCsrfFunctionOrString struct {
Expand Down
6 changes: 6 additions & 0 deletions recipe/session/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config
useDynamicSigningKey = *config.UseDynamicAccessTokenSigningKey
}

var jwksRefreshIntervalSec uint64 = 4 * 3600 // 4 hours
if config != nil && config.JWKSRefreshIntervalSec != nil {
jwksRefreshIntervalSec = *config.JWKSRefreshIntervalSec
}

typeNormalisedInput := sessmodels.TypeNormalisedInput{
RefreshTokenPath: appInfo.APIBasePath.AppendPath(refreshAPIPath),
CookieDomain: cookieDomain,
Expand All @@ -233,6 +238,7 @@ func ValidateAndNormaliseUserInput(appInfo supertokens.NormalisedAppinfo, config
AntiCsrfFunctionOrString: antiCsrfFunctionOrString,
ExposeAccessTokenToFrontendInCookieBasedAuth: config.ExposeAccessTokenToFrontendInCookieBasedAuth,
UseDynamicAccessTokenSigningKey: useDynamicSigningKey,
JWKSRefreshIntervalSec: jwksRefreshIntervalSec,
ErrorHandlers: errorHandlers,
GetTokenTransferMethod: config.GetTokenTransferMethod,
Override: sessmodels.OverrideStruct{
Expand Down

0 comments on commit 29eef08

Please sign in to comment.