diff --git a/recipe/dashboard/recipe.go b/recipe/dashboard/recipe.go index 135ba154..d5502461 100644 --- a/recipe/dashboard/recipe.go +++ b/recipe/dashboard/recipe.go @@ -350,7 +350,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/emailpassword/recipe.go b/recipe/emailpassword/recipe.go index f37f9d19..0d501ccc 100644 --- a/recipe/emailpassword/recipe.go +++ b/recipe/emailpassword/recipe.go @@ -181,7 +181,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { if defaultErrors.As(err, &errors.FieldError{}) { errs := err.(errors.FieldError) return true, supertokens.Send200Response(res, map[string]interface{}{ diff --git a/recipe/emailverification/recipe.go b/recipe/emailverification/recipe.go index e9889302..f1c02bb6 100644 --- a/recipe/emailverification/recipe.go +++ b/recipe/emailverification/recipe.go @@ -199,7 +199,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/jwt/recipe.go b/recipe/jwt/recipe.go index af7a6053..a1668da0 100644 --- a/recipe/jwt/recipe.go +++ b/recipe/jwt/recipe.go @@ -107,7 +107,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/multitenancy/recipe.go b/recipe/multitenancy/recipe.go index 74fbd89d..7c514829 100644 --- a/recipe/multitenancy/recipe.go +++ b/recipe/multitenancy/recipe.go @@ -144,7 +144,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/openid/recipe.go b/recipe/openid/recipe.go index 1594b305..c87632cd 100644 --- a/recipe/openid/recipe.go +++ b/recipe/openid/recipe.go @@ -127,8 +127,8 @@ func (r *Recipe) getAllCORSHeaders() []string { return r.JwtRecipe.RecipeModule.GetAllCORSHeaders() } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { - return r.JwtRecipe.RecipeModule.HandleError(err, req, res) +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { + return r.JwtRecipe.RecipeModule.HandleError(err, req, res, userContext) } func ResetForTest() { diff --git a/recipe/passwordless/recipe.go b/recipe/passwordless/recipe.go index 91eaf853..244a80d3 100644 --- a/recipe/passwordless/recipe.go +++ b/recipe/passwordless/recipe.go @@ -188,7 +188,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/session/config_test.go b/recipe/session/config_test.go index ab68b6e1..b7440f83 100644 --- a/recipe/session/config_test.go +++ b/recipe/session/config_test.go @@ -437,8 +437,10 @@ func TestSuperTokensInitWithNoneLaxFalseSessionConfigResults(t *testing.T) { if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "NONE") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "NONE") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "lax") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, false) } @@ -475,8 +477,10 @@ func TestSuperTokensInitWithCustomHeaderLaxTrueSessionConfigResults(t *testing.T if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "lax") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true) } @@ -514,8 +518,10 @@ func TestSuperTokensInitWithCustomHeaderLaxFalseSessionConfigResults(t *testing. if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "lax") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, false) } @@ -548,8 +554,10 @@ func TestSuperTokensInitWithCustomHeaderNoneTrueSessionConfigResultsWithNormalWe if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "none") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "none") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true) } @@ -582,8 +590,10 @@ func TestSuperTokensInitWithCustomHeaderNoneTrueSessionConfigResultsWithLocalWeb if err != nil { t.Error(err.Error()) } - assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER") - assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "none") + assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER") + cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "none") assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true) } @@ -619,11 +629,11 @@ func TestSuperTokensWithAntiCSRFNone(t *testing.T) { if err != nil { t.Error(err.Error()) } - singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError() if err != nil { t.Error(err.Error()) } - assert.Equal(t, singletoneSessionRecipeInstance.Config.AntiCsrf, "NONE") + assert.Equal(t, singletonSessionRecipeInstance.Config.AntiCsrfFunctionOrString.StrValue, "NONE") } func TestSuperTokensWithAntiCSRFRandom(t *testing.T) { @@ -737,12 +747,14 @@ func TestSuperTokensForTheDefaultCookieValues(t *testing.T) { if err != nil { t.Error(err.Error()) } - singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError() if err != nil { t.Error(err.Error()) } - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSecure, true) - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSameSite, "none") + assert.Equal(t, singletonSessionRecipeInstance.Config.CookieSecure, true) + cookieSameSite, err := singletonSessionRecipeInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "none") } func TestSuperTokensInitWithWrongConfigSchema(t *testing.T) { @@ -867,15 +879,17 @@ func TestSuperTokensDefaultCookieConfig(t *testing.T) { if err != nil { t.Error(err.Error()) } - singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError() + singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError() if err != nil { t.Error(err.Error()) } - assert.Nil(t, singletoneSessionRecipeInstance.Config.CookieDomain) - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSameSite, "lax") - assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSecure, true) - assert.Equal(t, singletoneSessionRecipeInstance.Config.RefreshTokenPath.GetAsStringDangerous(), "/auth/session/refresh") - assert.Equal(t, singletoneSessionRecipeInstance.Config.SessionExpiredStatusCode, 401) + assert.Nil(t, singletonSessionRecipeInstance.Config.CookieDomain) + cookieSameSite, err := singletonSessionRecipeInstance.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSite, "lax") + assert.Equal(t, singletonSessionRecipeInstance.Config.CookieSecure, true) + assert.Equal(t, singletonSessionRecipeInstance.Config.RefreshTokenPath.GetAsStringDangerous(), "/auth/session/refresh") + assert.Equal(t, singletonSessionRecipeInstance.Config.SessionExpiredStatusCode, 401) } func TestSuperTokensInitWithAPIGateWayPath(t *testing.T) { @@ -1256,7 +1270,9 @@ func TestCookieSameSiteWithEC2PublicURL(t *testing.T) { } assert.True(t, recipe.Config.CookieDomain == nil) - assert.Equal(t, recipe.Config.CookieSameSite, "none") + cookieSameSiteValue, err := recipe.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSiteValue, "none") assert.True(t, recipe.Config.CookieSecure) resetAll() @@ -1293,6 +1309,8 @@ func TestCookieSameSiteWithEC2PublicURL(t *testing.T) { } assert.True(t, recipe.Config.CookieDomain == nil) - assert.Equal(t, recipe.Config.CookieSameSite, "lax") + cookieSameSiteValue, err = recipe.Config.GetCookieSameSite(nil, nil) + assert.True(t, err != nil) + assert.Equal(t, cookieSameSiteValue, "lax") assert.False(t, recipe.Config.CookieSecure) } diff --git a/recipe/session/middleware.go b/recipe/session/middleware.go index c6125e5d..a813ee66 100644 --- a/recipe/session/middleware.go +++ b/recipe/session/middleware.go @@ -36,7 +36,7 @@ func VerifySessionHelper(recipeInstance Recipe, options *sessmodels.VerifySessio RecipeImplementation: recipeInstance.RecipeImpl, }, userContext) if err != nil { - err = supertokens.ErrorHandler(err, r, dw) + err = supertokens.ErrorHandler(err, r, dw, userContext) if err != nil { recipeInstance.RecipeModule.OnSuperTokensAPIError(err, r, dw) } diff --git a/recipe/thirdparty/recipe.go b/recipe/thirdparty/recipe.go index d564ae44..82db5460 100644 --- a/recipe/thirdparty/recipe.go +++ b/recipe/thirdparty/recipe.go @@ -155,7 +155,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { if errors.As(err, &tperrors.ClientTypeNotFoundError{}) { supertokens.SendNon200ResponseWithMessage(res, err.Error(), 400) return true, nil diff --git a/recipe/thirdpartyemailpassword/recipe.go b/recipe/thirdpartyemailpassword/recipe.go index 3743dd1e..382f385b 100644 --- a/recipe/thirdpartyemailpassword/recipe.go +++ b/recipe/thirdpartyemailpassword/recipe.go @@ -201,13 +201,13 @@ func (r *Recipe) getAllCORSHeaders() []string { return corsHeaders } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { - handleError, err := r.emailPasswordRecipe.RecipeModule.HandleError(err, req, res) +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { + handleError, err := r.emailPasswordRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } if r.thirdPartyRecipe != nil { - handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res) + handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } diff --git a/recipe/thirdpartypasswordless/recipe.go b/recipe/thirdpartypasswordless/recipe.go index 8b03c1fa..e4566516 100644 --- a/recipe/thirdpartypasswordless/recipe.go +++ b/recipe/thirdpartypasswordless/recipe.go @@ -204,13 +204,13 @@ func (r *Recipe) getAllCORSHeaders() []string { return corsHeaders } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { - handleError, err := r.passwordlessRecipe.RecipeModule.HandleError(err, req, res) +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { + handleError, err := r.passwordlessRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } if r.thirdPartyRecipe != nil { - handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res) + handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res, userContext) if err != nil || handleError { return handleError, err } diff --git a/recipe/usermetadata/recipe.go b/recipe/usermetadata/recipe.go index 2fb0cec2..4923e017 100644 --- a/recipe/usermetadata/recipe.go +++ b/recipe/usermetadata/recipe.go @@ -86,7 +86,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/recipe/userroles/recipe.go b/recipe/userroles/recipe.go index c49af9ba..2b961a6f 100644 --- a/recipe/userroles/recipe.go +++ b/recipe/userroles/recipe.go @@ -106,7 +106,7 @@ func (r *Recipe) getAllCORSHeaders() []string { return []string{} } -func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) { +func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) { return false, nil } diff --git a/supertokens/main.go b/supertokens/main.go index 338919b1..df5a8a8b 100644 --- a/supertokens/main.go +++ b/supertokens/main.go @@ -41,12 +41,15 @@ func Middleware(theirHandler http.Handler) http.Handler { return instance.middleware(theirHandler) } -func ErrorHandler(err error, req *http.Request, res http.ResponseWriter) error { +func ErrorHandler(err error, req *http.Request, res http.ResponseWriter, userContext ...UserContext) error { instance, instanceErr := GetInstanceOrThrowError() if instanceErr != nil { return instanceErr } - return instance.errorHandler(err, req, res) + if len(userContext) == 0 { + userContext = append(userContext, &map[string]interface{}{}) + } + return instance.errorHandler(err, req, res, userContext[0]) } func GetAllCORSHeaders() []string { diff --git a/supertokens/supertokens.go b/supertokens/supertokens.go index c282791a..b3bc166e 100644 --- a/supertokens/supertokens.go +++ b/supertokens/supertokens.go @@ -160,7 +160,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { userContext := MakeDefaultUserContextFromAPI(r) reqURL, err := NewNormalisedURLPath(r.URL.Path) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -200,7 +200,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { id, tenantId, err := matchedRecipe.ReturnAPIIdIfCanHandleRequest(path, method, userContext) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -217,7 +217,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { tenantId, err = GetTenantIdFuncFromUsingMultitenancyRecipe(tenantId, userContext) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -226,7 +226,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { apiErr := matchedRecipe.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) if apiErr != nil { - apiErr = s.errorHandler(apiErr, r, dw) + apiErr = s.errorHandler(apiErr, r, dw, userContext) if apiErr != nil && !dw.IsDone() { s.OnSuperTokensAPIError(apiErr, r, dw) } @@ -238,7 +238,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { id, tenantId, err := recipeModule.ReturnAPIIdIfCanHandleRequest(path, method, userContext) LogDebugMessage("middleware: Checking recipe ID for match: " + recipeModule.GetRecipeID()) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -249,7 +249,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { LogDebugMessage("middleware: Request being handled by recipe. ID is: " + *id) err := recipeModule.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) if err != nil { - err = s.errorHandler(err, r, dw) + err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { s.OnSuperTokensAPIError(err, r, dw) } @@ -281,7 +281,7 @@ func (s *superTokens) getAllCORSHeaders() []string { return headers } -func (s *superTokens) errorHandler(originalError error, req *http.Request, res http.ResponseWriter) error { +func (s *superTokens) errorHandler(originalError error, req *http.Request, res http.ResponseWriter, userContext UserContext) error { LogDebugMessage("errorHandler: Started") if errors.As(originalError, &BadInputError{}) { LogDebugMessage("errorHandler: Sending 400 status code response") @@ -299,7 +299,7 @@ func (s *superTokens) errorHandler(originalError error, req *http.Request, res h LogDebugMessage("errorHandler: Checking recipe for match: " + recipe.recipeID) if recipe.HandleError != nil { LogDebugMessage("errorHandler: Matched with recipeId: " + recipe.recipeID) - handled, err := recipe.HandleError(originalError, req, res) + handled, err := recipe.HandleError(originalError, req, res, userContext) if err != nil { return err }