Skip to content

Commit

Permalink
completes all lib code changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabhpoddar committed Nov 23, 2023
1 parent ace4b09 commit fc0aeb7
Show file tree
Hide file tree
Showing 16 changed files with 72 additions and 51 deletions.
2 changes: 1 addition & 1 deletion recipe/dashboard/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
2 changes: 1 addition & 1 deletion recipe/emailpassword/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
if defaultErrors.As(err, &errors.FieldError{}) {
errs := err.(errors.FieldError)
return true, supertokens.Send200Response(res, map[string]interface{}{
Expand Down
2 changes: 1 addition & 1 deletion recipe/emailverification/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
2 changes: 1 addition & 1 deletion recipe/jwt/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
2 changes: 1 addition & 1 deletion recipe/multitenancy/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
4 changes: 2 additions & 2 deletions recipe/openid/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ func (r *Recipe) getAllCORSHeaders() []string {
return r.JwtRecipe.RecipeModule.GetAllCORSHeaders()
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
return r.JwtRecipe.RecipeModule.HandleError(err, req, res)
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return r.JwtRecipe.RecipeModule.HandleError(err, req, res, userContext)
}

func ResetForTest() {
Expand Down
2 changes: 1 addition & 1 deletion recipe/passwordless/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
64 changes: 41 additions & 23 deletions recipe/session/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,10 @@ func TestSuperTokensInitWithNoneLaxFalseSessionConfigResults(t *testing.T) {
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "NONE")
assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax")
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "NONE")
cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "lax")
assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, false)
}

Expand Down Expand Up @@ -475,8 +477,10 @@ func TestSuperTokensInitWithCustomHeaderLaxTrueSessionConfigResults(t *testing.T
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER")
assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax")
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER")
cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "lax")
assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true)
}

Expand Down Expand Up @@ -514,8 +518,10 @@ func TestSuperTokensInitWithCustomHeaderLaxFalseSessionConfigResults(t *testing.
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER")
assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "lax")
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER")
cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "lax")
assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, false)
}

Expand Down Expand Up @@ -548,8 +554,10 @@ func TestSuperTokensInitWithCustomHeaderNoneTrueSessionConfigResultsWithNormalWe
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER")
assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "none")
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER")
cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "none")
assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true)
}

Expand Down Expand Up @@ -582,8 +590,10 @@ func TestSuperTokensInitWithCustomHeaderNoneTrueSessionConfigResultsWithLocalWeb
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrf, "VIA_CUSTOM_HEADER")
assert.Equal(t, sessionSingletonInstance.Config.CookieSameSite, "none")
assert.Equal(t, sessionSingletonInstance.Config.AntiCsrfFunctionOrString.StrValue, "VIA_CUSTOM_HEADER")
cookieSameSite, err := sessionSingletonInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "none")
assert.Equal(t, sessionSingletonInstance.Config.CookieSecure, true)
}

Expand Down Expand Up @@ -619,11 +629,11 @@ func TestSuperTokensWithAntiCSRFNone(t *testing.T) {
if err != nil {
t.Error(err.Error())
}
singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError()
singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError()
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, singletoneSessionRecipeInstance.Config.AntiCsrf, "NONE")
assert.Equal(t, singletonSessionRecipeInstance.Config.AntiCsrfFunctionOrString.StrValue, "NONE")
}

func TestSuperTokensWithAntiCSRFRandom(t *testing.T) {
Expand Down Expand Up @@ -737,12 +747,14 @@ func TestSuperTokensForTheDefaultCookieValues(t *testing.T) {
if err != nil {
t.Error(err.Error())
}
singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError()
singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError()
if err != nil {
t.Error(err.Error())
}
assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSecure, true)
assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSameSite, "none")
assert.Equal(t, singletonSessionRecipeInstance.Config.CookieSecure, true)
cookieSameSite, err := singletonSessionRecipeInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "none")
}

func TestSuperTokensInitWithWrongConfigSchema(t *testing.T) {
Expand Down Expand Up @@ -867,15 +879,17 @@ func TestSuperTokensDefaultCookieConfig(t *testing.T) {
if err != nil {
t.Error(err.Error())
}
singletoneSessionRecipeInstance, err := getRecipeInstanceOrThrowError()
singletonSessionRecipeInstance, err := getRecipeInstanceOrThrowError()
if err != nil {
t.Error(err.Error())
}
assert.Nil(t, singletoneSessionRecipeInstance.Config.CookieDomain)
assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSameSite, "lax")
assert.Equal(t, singletoneSessionRecipeInstance.Config.CookieSecure, true)
assert.Equal(t, singletoneSessionRecipeInstance.Config.RefreshTokenPath.GetAsStringDangerous(), "/auth/session/refresh")
assert.Equal(t, singletoneSessionRecipeInstance.Config.SessionExpiredStatusCode, 401)
assert.Nil(t, singletonSessionRecipeInstance.Config.CookieDomain)
cookieSameSite, err := singletonSessionRecipeInstance.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSite, "lax")
assert.Equal(t, singletonSessionRecipeInstance.Config.CookieSecure, true)
assert.Equal(t, singletonSessionRecipeInstance.Config.RefreshTokenPath.GetAsStringDangerous(), "/auth/session/refresh")
assert.Equal(t, singletonSessionRecipeInstance.Config.SessionExpiredStatusCode, 401)
}

func TestSuperTokensInitWithAPIGateWayPath(t *testing.T) {
Expand Down Expand Up @@ -1256,7 +1270,9 @@ func TestCookieSameSiteWithEC2PublicURL(t *testing.T) {
}

assert.True(t, recipe.Config.CookieDomain == nil)
assert.Equal(t, recipe.Config.CookieSameSite, "none")
cookieSameSiteValue, err := recipe.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSiteValue, "none")
assert.True(t, recipe.Config.CookieSecure)

resetAll()
Expand Down Expand Up @@ -1293,6 +1309,8 @@ func TestCookieSameSiteWithEC2PublicURL(t *testing.T) {
}

assert.True(t, recipe.Config.CookieDomain == nil)
assert.Equal(t, recipe.Config.CookieSameSite, "lax")
cookieSameSiteValue, err = recipe.Config.GetCookieSameSite(nil, nil)
assert.True(t, err != nil)
assert.Equal(t, cookieSameSiteValue, "lax")
assert.False(t, recipe.Config.CookieSecure)
}
2 changes: 1 addition & 1 deletion recipe/session/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func VerifySessionHelper(recipeInstance Recipe, options *sessmodels.VerifySessio
RecipeImplementation: recipeInstance.RecipeImpl,
}, userContext)
if err != nil {
err = supertokens.ErrorHandler(err, r, dw)
err = supertokens.ErrorHandler(err, r, dw, userContext)
if err != nil {
recipeInstance.RecipeModule.OnSuperTokensAPIError(err, r, dw)
}
Expand Down
2 changes: 1 addition & 1 deletion recipe/thirdparty/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
if errors.As(err, &tperrors.ClientTypeNotFoundError{}) {
supertokens.SendNon200ResponseWithMessage(res, err.Error(), 400)
return true, nil
Expand Down
6 changes: 3 additions & 3 deletions recipe/thirdpartyemailpassword/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ func (r *Recipe) getAllCORSHeaders() []string {
return corsHeaders
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
handleError, err := r.emailPasswordRecipe.RecipeModule.HandleError(err, req, res)
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
handleError, err := r.emailPasswordRecipe.RecipeModule.HandleError(err, req, res, userContext)
if err != nil || handleError {
return handleError, err
}
if r.thirdPartyRecipe != nil {
handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res)
handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res, userContext)
if err != nil || handleError {
return handleError, err
}
Expand Down
6 changes: 3 additions & 3 deletions recipe/thirdpartypasswordless/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,13 @@ func (r *Recipe) getAllCORSHeaders() []string {
return corsHeaders
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
handleError, err := r.passwordlessRecipe.RecipeModule.HandleError(err, req, res)
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
handleError, err := r.passwordlessRecipe.RecipeModule.HandleError(err, req, res, userContext)
if err != nil || handleError {
return handleError, err
}
if r.thirdPartyRecipe != nil {
handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res)
handleError, err = r.thirdPartyRecipe.RecipeModule.HandleError(err, req, res, userContext)
if err != nil || handleError {
return handleError, err
}
Expand Down
2 changes: 1 addition & 1 deletion recipe/usermetadata/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
2 changes: 1 addition & 1 deletion recipe/userroles/recipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (r *Recipe) getAllCORSHeaders() []string {
return []string{}
}

func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter) (bool, error) {
func (r *Recipe) handleError(err error, req *http.Request, res http.ResponseWriter, userContext supertokens.UserContext) (bool, error) {
return false, nil
}

Expand Down
7 changes: 5 additions & 2 deletions supertokens/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ func Middleware(theirHandler http.Handler) http.Handler {
return instance.middleware(theirHandler)
}

func ErrorHandler(err error, req *http.Request, res http.ResponseWriter) error {
func ErrorHandler(err error, req *http.Request, res http.ResponseWriter, userContext ...UserContext) error {
instance, instanceErr := GetInstanceOrThrowError()
if instanceErr != nil {
return instanceErr
}
return instance.errorHandler(err, req, res)
if len(userContext) == 0 {
userContext = append(userContext, &map[string]interface{}{})
}
return instance.errorHandler(err, req, res, userContext[0])
}

func GetAllCORSHeaders() []string {
Expand Down
16 changes: 8 additions & 8 deletions supertokens/supertokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler {
userContext := MakeDefaultUserContextFromAPI(r)
reqURL, err := NewNormalisedURLPath(r.URL.Path)
if err != nil {
err = s.errorHandler(err, r, dw)
err = s.errorHandler(err, r, dw, userContext)
if err != nil && !dw.IsDone() {
s.OnSuperTokensAPIError(err, r, dw)
}
Expand Down Expand Up @@ -200,7 +200,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler {
id, tenantId, err := matchedRecipe.ReturnAPIIdIfCanHandleRequest(path, method, userContext)

if err != nil {
err = s.errorHandler(err, r, dw)
err = s.errorHandler(err, r, dw, userContext)
if err != nil && !dw.IsDone() {
s.OnSuperTokensAPIError(err, r, dw)
}
Expand All @@ -217,7 +217,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler {

tenantId, err = GetTenantIdFuncFromUsingMultitenancyRecipe(tenantId, userContext)
if err != nil {
err = s.errorHandler(err, r, dw)
err = s.errorHandler(err, r, dw, userContext)
if err != nil && !dw.IsDone() {
s.OnSuperTokensAPIError(err, r, dw)
}
Expand All @@ -226,7 +226,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler {

apiErr := matchedRecipe.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext)
if apiErr != nil {
apiErr = s.errorHandler(apiErr, r, dw)
apiErr = s.errorHandler(apiErr, r, dw, userContext)
if apiErr != nil && !dw.IsDone() {
s.OnSuperTokensAPIError(apiErr, r, dw)
}
Expand All @@ -238,7 +238,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler {
id, tenantId, err := recipeModule.ReturnAPIIdIfCanHandleRequest(path, method, userContext)
LogDebugMessage("middleware: Checking recipe ID for match: " + recipeModule.GetRecipeID())
if err != nil {
err = s.errorHandler(err, r, dw)
err = s.errorHandler(err, r, dw, userContext)
if err != nil && !dw.IsDone() {
s.OnSuperTokensAPIError(err, r, dw)
}
Expand All @@ -249,7 +249,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler {
LogDebugMessage("middleware: Request being handled by recipe. ID is: " + *id)
err := recipeModule.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext)
if err != nil {
err = s.errorHandler(err, r, dw)
err = s.errorHandler(err, r, dw, userContext)
if err != nil && !dw.IsDone() {
s.OnSuperTokensAPIError(err, r, dw)
}
Expand Down Expand Up @@ -281,7 +281,7 @@ func (s *superTokens) getAllCORSHeaders() []string {
return headers
}

func (s *superTokens) errorHandler(originalError error, req *http.Request, res http.ResponseWriter) error {
func (s *superTokens) errorHandler(originalError error, req *http.Request, res http.ResponseWriter, userContext UserContext) error {
LogDebugMessage("errorHandler: Started")
if errors.As(originalError, &BadInputError{}) {
LogDebugMessage("errorHandler: Sending 400 status code response")
Expand All @@ -299,7 +299,7 @@ func (s *superTokens) errorHandler(originalError error, req *http.Request, res h
LogDebugMessage("errorHandler: Checking recipe for match: " + recipe.recipeID)
if recipe.HandleError != nil {
LogDebugMessage("errorHandler: Matched with recipeId: " + recipe.recipeID)
handled, err := recipe.HandleError(originalError, req, res)
handled, err := recipe.HandleError(originalError, req, res, userContext)
if err != nil {
return err
}
Expand Down

0 comments on commit fc0aeb7

Please sign in to comment.