diff --git a/CHANGELOG.md b/CHANGELOG.md index 68f58868..dea3632e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.15.0] - 2023-09-26 +- Adds Twitter/X as a default provider to the third party recipe - Added a `Cache-Control` header to `/jwt/jwks.json` (`GetJWKSGET`) - Added `ValidityInSeconds` to the return value of the overrideable `GetJWKS` function. - This can be used to control the `Cache-Control` header mentioned above. diff --git a/recipe/thirdparty/providers/config_utils.go b/recipe/thirdparty/providers/config_utils.go index 0be068c1..b40cf462 100644 --- a/recipe/thirdparty/providers/config_utils.go +++ b/recipe/thirdparty/providers/config_utils.go @@ -75,6 +75,8 @@ func createProvider(input tpmodels.ProviderInput) *tpmodels.TypeProvider { return Linkedin(input) } else if strings.HasPrefix(input.Config.ThirdPartyId, "boxy-saml") { return BoxySaml(input) + } else if strings.HasPrefix(input.Config.ThirdPartyId, "twitter") { + return Twitter(input) } return NewProvider(input) diff --git a/recipe/thirdparty/providers/twitter.go b/recipe/thirdparty/providers/twitter.go new file mode 100644 index 00000000..cb641c72 --- /dev/null +++ b/recipe/thirdparty/providers/twitter.go @@ -0,0 +1,90 @@ +package providers + +import ( + "encoding/base64" + "github.com/supertokens/supertokens-golang/recipe/thirdparty/tpmodels" + "github.com/supertokens/supertokens-golang/supertokens" +) + +func Twitter(input tpmodels.ProviderInput) *tpmodels.TypeProvider { + if input.Config.Name == "" { + input.Config.Name = "Twitter" + } + + if input.Config.AuthorizationEndpoint == "" { + input.Config.AuthorizationEndpoint = "https://twitter.com/i/oauth2/authorize" + } + + if input.Config.TokenEndpoint == "" { + input.Config.TokenEndpoint = "https://api.twitter.com/2/oauth2/token" + } + + if input.Config.UserInfoEndpoint == "" { + input.Config.UserInfoEndpoint = "https://api.twitter.com/2/users/me" + } + + if input.Config.RequireEmail == nil { + False := false + input.Config.RequireEmail = &False + } + + if input.Config.UserInfoMap.FromUserInfoAPI.UserId == "" { + input.Config.UserInfoMap.FromUserInfoAPI.UserId = "data.id" + } + + oOverride := input.Override + + input.Override = func(originalImplementation *tpmodels.TypeProvider) *tpmodels.TypeProvider { + oGetConfig := originalImplementation.GetConfigForClientType + originalImplementation.GetConfigForClientType = func(clientType *string, userContext supertokens.UserContext) (tpmodels.ProviderConfigForClientType, error) { + config, err := oGetConfig(clientType, userContext) + if err != nil { + return tpmodels.ProviderConfigForClientType{}, err + } + + if len(config.Scope) == 0 { + config.Scope = []string{"users.read", "tweet.read"} + } + + if config.ForcePKCE == nil { + True := true + config.ForcePKCE = &True + } + + return config, nil + } + + originalImplementation.ExchangeAuthCodeForOAuthTokens = func(redirectURIInfo tpmodels.TypeRedirectURIInfo, userContext supertokens.UserContext) (tpmodels.TypeOAuthTokens, error) { + basicAuthToken := base64.StdEncoding.EncodeToString([]byte(originalImplementation.Config.ClientID + ":" + originalImplementation.Config.ClientSecret)) + twitterOauthParams := map[string]interface{}{} + + if originalImplementation.Config.TokenEndpointBodyParams != nil { + twitterOauthParams = originalImplementation.Config.TokenEndpointBodyParams + } + + codeVerifier := "" + + if redirectURIInfo.PKCECodeVerifier != nil { + codeVerifier = *redirectURIInfo.PKCECodeVerifier + } + + twitterOauthParams["grant_type"] = "authorization_code" + twitterOauthParams["client_id"] = originalImplementation.Config.ClientID + twitterOauthParams["code_verifier"] = codeVerifier + twitterOauthParams["redirect_uri"] = redirectURIInfo.RedirectURIOnProviderDashboard + twitterOauthParams["code"] = redirectURIInfo.RedirectURIQueryParams["code"] + + return doPostRequest(originalImplementation.Config.TokenEndpoint, twitterOauthParams, map[string]interface{}{ + "Authorization": "Basic " + basicAuthToken, + }) + } + + if oOverride != nil { + originalImplementation = oOverride(originalImplementation) + } + + return originalImplementation + } + + return NewProvider(input) +}