diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f031e8e..1ca35438 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] +- `session.CreateNewSession` now defaults to the value of the `st-auth-mode` header (if available) if the configured `config.GetTokenTransferMethod` returns `any`. ## [0.17.5] - 2024-03-14 - Adds a type uint64 to the `accessTokenCookiesExpiryDurationMillis` local variable in `recipe/session/utils.go`. It also removes the redundant `uint64` type forcing needed because of the untyped variable. diff --git a/recipe/emailpassword/authMode_test.go b/recipe/emailpassword/authMode_test.go index 17146498..10a7c4ea 100644 --- a/recipe/emailpassword/authMode_test.go +++ b/recipe/emailpassword/authMode_test.go @@ -224,15 +224,43 @@ func TestWithGetTokenTransferMethodProvidedCreateNewSessionWithShouldUseHeaderIf defer testServer.Close() setupRoutesForTest(t, mux) - resp := createNewSession(t, testServer.URL, nil, nil, nil, nil) + t.Run("no st-auth-mode", func(t *testing.T) { + resp := createNewSession(t, testServer.URL, nil, nil, nil, nil) + + assert.Equal(t, resp["sAccessToken"], "-not-present-") + assert.Equal(t, resp["sRefreshToken"], "-not-present-") + assert.Equal(t, resp["antiCsrf"], "-not-present-") + assert.NotEmpty(t, resp["accessTokenFromHeader"]) + assert.NotEqual(t, resp["accessTokenFromHeader"], "-not-present-") + assert.NotEmpty(t, resp["refreshTokenFromHeader"]) + assert.NotEqual(t, resp["refreshTokenFromHeader"], "-not-present-") + }) - assert.Equal(t, resp["sAccessToken"], "-not-present-") - assert.Equal(t, resp["sRefreshToken"], "-not-present-") - assert.Equal(t, resp["antiCsrf"], "-not-present-") - assert.NotEmpty(t, resp["accessTokenFromHeader"]) - assert.NotEqual(t, resp["accessTokenFromHeader"], "-not-present-") - assert.NotEmpty(t, resp["refreshTokenFromHeader"]) - assert.NotEqual(t, resp["refreshTokenFromHeader"], "-not-present-") + t.Run("st-auth-mode is cookie", func(t *testing.T) { + authMode := string(sessmodels.CookieTransferMethod) + resp := createNewSession(t, testServer.URL, &authMode, nil, nil, nil) + + assert.NotEqual(t, resp["sAccessToken"], "-not-present-") + assert.NotEqual(t, resp["sRefreshToken"], "-not-present-") + assert.NotEqual(t, resp["antiCsrf"], "-not-present-") + assert.NotEmpty(t, resp["accessTokenFromHeader"]) + assert.Equal(t, resp["accessTokenFromHeader"], "-not-present-") + assert.NotEmpty(t, resp["refreshTokenFromHeader"]) + assert.Equal(t, resp["refreshTokenFromHeader"], "-not-present-") + }) + + t.Run("st-auth-mode is header", func(t *testing.T) { + authMode := string(sessmodels.HeaderTransferMethod) + resp := createNewSession(t, testServer.URL, &authMode, nil, nil, nil) + + assert.Equal(t, resp["sAccessToken"], "-not-present-") + assert.Equal(t, resp["sRefreshToken"], "-not-present-") + assert.Equal(t, resp["antiCsrf"], "-not-present-") + assert.NotEmpty(t, resp["accessTokenFromHeader"]) + assert.NotEqual(t, resp["accessTokenFromHeader"], "-not-present-") + assert.NotEmpty(t, resp["refreshTokenFromHeader"]) + assert.NotEqual(t, resp["refreshTokenFromHeader"], "-not-present-") + }) } func TestWithGetTokenTransferMethodProvidedCreateNewSessionWithShouldUseHeaderIfMethodReturnsHeader(t *testing.T) { diff --git a/recipe/session/sessionRequestFunctions.go b/recipe/session/sessionRequestFunctions.go index c542a804..be796f27 100644 --- a/recipe/session/sessionRequestFunctions.go +++ b/recipe/session/sessionRequestFunctions.go @@ -60,7 +60,12 @@ func CreateNewSessionInRequest(req *http.Request, res http.ResponseWriter, tenan outputTokenTransferMethod := config.GetTokenTransferMethod(req, true, userContext) if outputTokenTransferMethod == sessmodels.AnyTransferMethod { - outputTokenTransferMethod = sessmodels.HeaderTransferMethod + authMode := GetAuthmodeFromHeader(req) + if authMode != nil && *authMode == sessmodels.CookieTransferMethod { + outputTokenTransferMethod = *authMode + } else { + outputTokenTransferMethod = sessmodels.HeaderTransferMethod + } } supertokens.LogDebugMessage(fmt.Sprintf("createNewSession: using transfer method %s", outputTokenTransferMethod))