From cb383a93c812973cc02626a0a993a5210128bd7e Mon Sep 17 00:00:00 2001 From: rishabhpoddar Date: Fri, 12 Apr 2024 12:48:50 +0530 Subject: [PATCH] makes change to middleware routing logic --- supertokens/supertokens.go | 120 +++++++++++++++++++++++-------------- 1 file changed, 76 insertions(+), 44 deletions(-) diff --git a/supertokens/supertokens.go b/supertokens/supertokens.go index b3bc166e..b5d5f825 100644 --- a/supertokens/supertokens.go +++ b/supertokens/supertokens.go @@ -181,41 +181,69 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { requestRID = "" } if requestRID != "" { - var matchedRecipe *RecipeModule + var matchedRecipes []RecipeModule = []RecipeModule{} for _, recipeModule := range s.RecipeModules { LogDebugMessage("middleware: Checking recipe ID for match: " + recipeModule.GetRecipeID()) if recipeModule.GetRecipeID() == requestRID { - matchedRecipe = &recipeModule - break + matchedRecipes = append(matchedRecipes, recipeModule) + } else if requestRID == "thirdpartyemailpassword" { + if recipeModule.GetRecipeID() == "thirdparty" || + recipeModule.GetRecipeID() == "emailpassword" { + matchedRecipes = append(matchedRecipes, recipeModule) + } + } else if requestRID == "thirdpartypasswordless" { + if recipeModule.GetRecipeID() == "thirdparty" || + recipeModule.GetRecipeID() == "passwordless" { + matchedRecipes = append(matchedRecipes, recipeModule) + } } } - if matchedRecipe == nil { - LogDebugMessage("middleware: Not handling because no recipe matched") - theirHandler.ServeHTTP(dw, r) + if len(matchedRecipes) == 0 { + LogDebugMessage("middleware: Not handling because no recipe matched. Trying without rid") + s.middlewareHelperHandleWithoutRid(path, method, userContext, theirHandler, dw, r) return } - LogDebugMessage("middleware: Matched with recipe ID: " + matchedRecipe.GetRecipeID()) + for _, matchedRecipe := range matchedRecipes { + LogDebugMessage("middleware: Matched with recipe IDs: " + matchedRecipe.GetRecipeID()) + } - id, tenantId, err := matchedRecipe.ReturnAPIIdIfCanHandleRequest(path, method, userContext) + var id *string = nil + var finalTenantId *string = nil + var finalMatchedRecipe *RecipeModule = nil - if err != nil { - err = s.errorHandler(err, r, dw, userContext) - if err != nil && !dw.IsDone() { - s.OnSuperTokensAPIError(err, r, dw) + for _, matchedRecipe := range matchedRecipes { + currId, currTenantId, err := matchedRecipe.ReturnAPIIdIfCanHandleRequest(path, method, userContext) + if err != nil { + err = s.errorHandler(err, r, dw, userContext) + if err != nil && !dw.IsDone() { + s.OnSuperTokensAPIError(err, r, dw) + } + return + } + + if currId != nil { + if id != nil { + if !dw.IsDone() { + s.OnSuperTokensAPIError(errors.New("Two recipes have matched the same API path and method! This is a bug in the SDK. Please contact support."), r, dw) + } + return + } else { + id = currId + finalTenantId = &currTenantId + finalMatchedRecipe = &matchedRecipe + } } - return } - if id == nil { - LogDebugMessage("middleware: Not handling because recipe doesn't handle request path or method. Request path: " + path.GetAsStringDangerous() + ", request method: " + method) - theirHandler.ServeHTTP(dw, r) + if id == nil || finalTenantId == nil || finalMatchedRecipe == nil { + s.middlewareHelperHandleWithoutRid(path, method, userContext, theirHandler, dw, r) return } LogDebugMessage("middleware: Request being handled by recipe. ID is: " + *id) - tenantId, err = GetTenantIdFuncFromUsingMultitenancyRecipe(tenantId, userContext) + var tenantId, err = GetTenantIdFuncFromUsingMultitenancyRecipe(*finalTenantId, userContext) if err != nil { err = s.errorHandler(err, r, dw, userContext) if err != nil && !dw.IsDone() { @@ -224,7 +252,7 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { return } - apiErr := matchedRecipe.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) + apiErr := finalMatchedRecipe.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) if apiErr != nil { apiErr = s.errorHandler(apiErr, r, dw, userContext) if apiErr != nil && !dw.IsDone() { @@ -234,36 +262,40 @@ func (s *superTokens) middleware(theirHandler http.Handler) http.Handler { } LogDebugMessage("middleware: Ended") } else { - for _, recipeModule := range s.RecipeModules { - id, tenantId, err := recipeModule.ReturnAPIIdIfCanHandleRequest(path, method, userContext) - LogDebugMessage("middleware: Checking recipe ID for match: " + recipeModule.GetRecipeID()) - if err != nil { - err = s.errorHandler(err, r, dw, userContext) - if err != nil && !dw.IsDone() { - s.OnSuperTokensAPIError(err, r, dw) - } - return - } + s.middlewareHelperHandleWithoutRid(path, method, userContext, theirHandler, dw, r) + } + }) +} - if id != nil { - LogDebugMessage("middleware: Request being handled by recipe. ID is: " + *id) - err := recipeModule.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) - if err != nil { - err = s.errorHandler(err, r, dw, userContext) - if err != nil && !dw.IsDone() { - s.OnSuperTokensAPIError(err, r, dw) - } - } else { - LogDebugMessage("middleware: Ended") - } - return - } +func (s *superTokens) middlewareHelperHandleWithoutRid(path NormalisedURLPath, method string, userContext *map[string]interface{}, theirHandler http.Handler, dw DoneWriter, r *http.Request) { + for _, recipeModule := range s.RecipeModules { + id, tenantId, err := recipeModule.ReturnAPIIdIfCanHandleRequest(path, method, userContext) + LogDebugMessage("middleware: Checking recipe ID for match: " + recipeModule.GetRecipeID()) + if err != nil { + err = s.errorHandler(err, r, dw, userContext) + if err != nil && !dw.IsDone() { + s.OnSuperTokensAPIError(err, r, dw) } + return + } - LogDebugMessage("middleware: Not handling because no recipe matched") - theirHandler.ServeHTTP(dw, r) + if id != nil { + LogDebugMessage("middleware: Request being handled by recipe. ID is: " + *id) + err := recipeModule.HandleAPIRequest(*id, tenantId, r, dw, theirHandler.ServeHTTP, path, method, userContext) + if err != nil { + err = s.errorHandler(err, r, dw, userContext) + if err != nil && !dw.IsDone() { + s.OnSuperTokensAPIError(err, r, dw) + } + } else { + LogDebugMessage("middleware: Ended") + } + return } - }) + } + + LogDebugMessage("middleware: Not handling because no recipe matched") + theirHandler.ServeHTTP(dw, r) } func (s *superTokens) getAllCORSHeaders() []string {