From c0622ecbf60360ba0c277827b908a466a868f9c6 Mon Sep 17 00:00:00 2001 From: Sattvik Chakravarthy Date: Fri, 26 Jul 2024 16:05:40 +0530 Subject: [PATCH] fix: thirdparty and multitenancy --- test/test-server/go.mod | 2 + test/test-server/main.go | 43 +++++- test/test-server/multitenancy.go | 220 +++++++++++++++++++++++++++++++ test/test-server/thirdparty.go | 131 ++++++++++++++++++ 4 files changed, 394 insertions(+), 2 deletions(-) create mode 100644 test/test-server/multitenancy.go create mode 100644 test/test-server/thirdparty.go diff --git a/test/test-server/go.mod b/test/test-server/go.mod index 91c42f4b..e4e1a44b 100644 --- a/test/test-server/go.mod +++ b/test/test-server/go.mod @@ -10,10 +10,12 @@ require ( require ( github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/derekstavis/go-qs v0.0.0-20180720192143-9eef69e6c4e7 // indirect github.com/golang-jwt/jwt/v5 v5.0.0 // indirect github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/stretchr/testify v1.7.0 // indirect + golang.org/x/crypto v0.2.0 // indirect golang.org/x/net v0.2.0 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df // indirect diff --git a/test/test-server/main.go b/test/test-server/main.go index 459b31c4..eac8aa67 100644 --- a/test/test-server/main.go +++ b/test/test-server/main.go @@ -15,6 +15,8 @@ import ( "github.com/supertokens/supertokens-golang/recipe/emailverification/evmodels" "github.com/supertokens/supertokens-golang/recipe/session" "github.com/supertokens/supertokens-golang/recipe/session/sessmodels" + "github.com/supertokens/supertokens-golang/recipe/thirdparty" + "github.com/supertokens/supertokens-golang/recipe/thirdparty/tpmodels" "github.com/supertokens/supertokens-golang/supertokens" ) @@ -51,10 +53,10 @@ func main() { addSessionRoutes(router) // addAccountLinkingRoutes(router) // addEmailVerificationRoutes(router) - // addMultitenancyRoutes(router) + addMultitenancyRoutes(router) // addPasswordlessRoutes(router) // addMultiFactorAuthRoutes(router) - // addThirdPartyRoutes(router) + addThirdPartyRoutes(router) // addTOTPRoutes(router) // addUserMetadataRoutes(router) @@ -306,6 +308,43 @@ func recipeListFromRecipeConfigs(recipeListMaps []interface{}) []supertokens.Rec } recipeList = append(recipeList, emailverification.Init(recipeConfig)) + } else if recipeItemMap["recipeId"] == "thirdparty" { + var recipeConfigMap map[string]interface{} + err := json.Unmarshal([]byte(recipeItemMap["config"].(string)), &recipeConfigMap) + if err != nil { + log.Printf("Error unmarshaling recipe config: %v", err) + continue + } + recipeConfig := tpmodels.TypeInput{} + + if signInAndUpFeature, ok := recipeConfigMap["signInAndUpFeature"].(map[string]interface{}); ok { + if providers, ok := signInAndUpFeature["providers"].([]interface{}); ok { + for _, provider := range providers { + providerInput := tpmodels.ProviderInput{} + providerMap := provider.(map[string]interface{}) + + if config, ok := providerMap["config"].(map[string]interface{}); ok { + configBytes, err := json.Marshal(config) + if err != nil { + log.Printf("Error marshaling provider config: %v", err) + continue + } + err = json.Unmarshal(configBytes, &providerInput.Config) + if err != nil { + log.Printf("Error unmarshaling provider config: %v", err) + continue + } + } + + if includeInNonPublicTenantsByDefault, ok := providerMap["includeInNonPublicTenantsByDefault"].(bool); ok { + providerInput.IncludeInNonPublicTenantsByDefault = &includeInNonPublicTenantsByDefault + } + recipeConfig.SignInAndUpFeature.Providers = append(recipeConfig.SignInAndUpFeature.Providers, providerInput) + } + } + } + + recipeList = append(recipeList, thirdparty.Init(&recipeConfig)) } } return recipeList diff --git a/test/test-server/multitenancy.go b/test/test-server/multitenancy.go new file mode 100644 index 00000000..3597c287 --- /dev/null +++ b/test/test-server/multitenancy.go @@ -0,0 +1,220 @@ +package main + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/supertokens/supertokens-golang/recipe/multitenancy" + "github.com/supertokens/supertokens-golang/recipe/multitenancy/multitenancymodels" + "github.com/supertokens/supertokens-golang/recipe/thirdparty/tpmodels" + "github.com/supertokens/supertokens-golang/supertokens" +) + +func addMultitenancyRoutes(router *mux.Router) { + router.HandleFunc("/test/multitenancy/createorupdatetenant", createOrUpdateTenantHandler).Methods("POST") + router.HandleFunc("/test/multitenancy/deletetenant", deleteTenantHandler).Methods("POST") + router.HandleFunc("/test/multitenancy/gettenant", getTenantHandler).Methods("POST") + router.HandleFunc("/test/multitenancy/listalltenants", listAllTenantsHandler).Methods("GET") + router.HandleFunc("/test/multitenancy/createorupdatethirdpartyconfig", createOrUpdateThirdPartyConfigHandler).Methods("POST") + router.HandleFunc("/test/multitenancy/deletethirdpartyconfig", deleteThirdPartyConfigHandler).Methods("POST") + router.HandleFunc("/test/multitenancy/associateusertotenant", associateUserToTenantHandler).Methods("POST") + router.HandleFunc("/test/multitenancy/disassociateuserfromtenant", disassociateUserFromTenantHandler).Methods("POST") +} + +func createOrUpdateTenantHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + Config multitenancymodels.TenantConfig `json:"config"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.CreateOrUpdateTenant(body.TenantId, body.Config, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func deleteTenantHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.DeleteTenant(body.TenantId, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func getTenantHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + tenant, err := multitenancy.GetTenant(body.TenantId, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(tenant) +} + +func listAllTenantsHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.ListAllTenants(userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func createOrUpdateThirdPartyConfigHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + Config tpmodels.ProviderConfig `json:"config"` + SkipValidation *bool `json:"skipValidation"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.CreateOrUpdateThirdPartyConfig(body.TenantId, body.Config, body.SkipValidation, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func deleteThirdPartyConfigHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + ThirdPartyId string `json:"thirdPartyId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.DeleteThirdPartyConfig(body.TenantId, body.ThirdPartyId, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func associateUserToTenantHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + UserId string `json:"userId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.AssociateUserToTenant(body.TenantId, body.UserId, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func disassociateUserFromTenantHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + UserId string `json:"userId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := multitenancy.DisassociateUserFromTenant(body.TenantId, body.UserId, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} diff --git a/test/test-server/thirdparty.go b/test/test-server/thirdparty.go new file mode 100644 index 00000000..ff32669e --- /dev/null +++ b/test/test-server/thirdparty.go @@ -0,0 +1,131 @@ +package main + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/supertokens/supertokens-golang/recipe/thirdparty" + "github.com/supertokens/supertokens-golang/supertokens" +) + +func addThirdPartyRoutes(router *mux.Router) { + router.HandleFunc("/test/thirdparty/manuallycreateorupdateuser", manuallyCreateOrUpdateUserHandler).Methods("POST") + router.HandleFunc("/test/thirdparty/getuserbyid", getUserByIDHandler).Methods("POST") + router.HandleFunc("/test/thirdparty/getusersbyemail", getUsersByEmailHandler).Methods("POST") + router.HandleFunc("/test/thirdparty/getuserbythirdpartyinfo", getUserByThirdPartyInfoHandler).Methods("POST") +} + +func manuallyCreateOrUpdateUserHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + ThirdPartyID string `json:"thirdPartyId"` + ThirdPartyUserID string `json:"thirdPartyUserId"` + Email string `json:"email"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if body.TenantId == "" { + body.TenantId = "public" + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + response, err := thirdparty.ManuallyCreateOrUpdateUser(body.TenantId, body.ThirdPartyID, body.ThirdPartyUserID, body.Email, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(response) +} + +func getUserByIDHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + UserID string `json:"userId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + user, err := thirdparty.GetUserByID(body.UserID, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(user) +} + +func getUsersByEmailHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + Email string `json:"email"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if body.TenantId == "" { + body.TenantId = "public" + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + users, err := thirdparty.GetUsersByEmail(body.TenantId, body.Email, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(users) +} + +func getUserByThirdPartyInfoHandler(w http.ResponseWriter, r *http.Request) { + var body struct { + TenantId string `json:"tenantId"` + ThirdPartyID string `json:"thirdPartyId"` + ThirdPartyUserID string `json:"thirdPartyUserId"` + UserContext map[string]interface{} `json:"userContext"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if body.TenantId == "" { + body.TenantId = "public" + } + + var userContext supertokens.UserContext = nil + if body.UserContext != nil { + userContext = &body.UserContext + } + + user, err := thirdparty.GetUserByThirdPartyInfo(body.TenantId, body.ThirdPartyID, body.ThirdPartyUserID, userContext) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(user) +}