From 7276db13bb60b808c4ab8d9c6048660a2f1a78e7 Mon Sep 17 00:00:00 2001 From: Frederic Jahn Date: Wed, 17 Apr 2024 13:55:18 +0200 Subject: [PATCH] fix: fix saml login for existing users (#1434) --- backend/ee/saml/handler.go | 2 +- backend/handler/thirdparty.go | 2 +- backend/thirdparty/linking.go | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/backend/ee/saml/handler.go b/backend/ee/saml/handler.go index 692277dab..b158c4578 100644 --- a/backend/ee/saml/handler.go +++ b/backend/ee/saml/handler.go @@ -214,7 +214,7 @@ func (handler *SamlHandler) linkAccount(c echo.Context, redirectTo *url.URL, sta samlError = handler.persister.Transaction(func(tx *pop.Connection) error { userdata := provider.GetUserData(assertionInfo) - linkResult, samlError := thirdparty.LinkAccount(tx, handler.config, handler.persister, userdata, state.Provider) + linkResult, samlError := thirdparty.LinkAccount(tx, handler.config, handler.persister, userdata, state.Provider, true) if samlError != nil { return samlError } diff --git a/backend/handler/thirdparty.go b/backend/handler/thirdparty.go index 6f48146bc..015f8fa95 100644 --- a/backend/handler/thirdparty.go +++ b/backend/handler/thirdparty.go @@ -143,7 +143,7 @@ func (h *ThirdPartyHandler) Callback(c echo.Context) error { return thirdparty.ErrorInvalidRequest("could not retrieve user data from provider").WithCause(terr) } - linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.Name()) + linkingResult, terr := thirdparty.LinkAccount(tx, h.cfg, h.persister, userData, provider.Name(), false) if terr != nil { return terr } diff --git a/backend/thirdparty/linking.go b/backend/thirdparty/linking.go index 7defb951d..d4a4b23ce 100644 --- a/backend/thirdparty/linking.go +++ b/backend/thirdparty/linking.go @@ -19,7 +19,7 @@ const ( getIdentityFailure = "could not get identity" ) -func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string) (*AccountLinkingResult, error) { +func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, isSaml bool) (*AccountLinkingResult, error) { if cfg.Emails.RequireVerification && !userData.Metadata.EmailVerified { return nil, ErrorUnverifiedProviderEmail("third party provider email must be verified") } @@ -38,15 +38,15 @@ func LinkAccount(tx *pop.Connection, cfg *config.Config, p persistence.Persister if user == nil { return signUp(tx, cfg, p, userData, providerName) } else { - return link(tx, cfg, p, userData, providerName, user) + return link(tx, cfg, p, userData, providerName, user, isSaml) } } else { return signIn(tx, cfg, p, userData, identity) } } -func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, user *models.User) (*AccountLinkingResult, error) { - if !cfg.ThirdParty.Providers.Get(providerName).AllowLinking { +func link(tx *pop.Connection, cfg *config.Config, p persistence.Persister, userData *UserData, providerName string, user *models.User, isSaml bool) (*AccountLinkingResult, error) { + if !isSaml && !cfg.ThirdParty.Providers.Get(providerName).AllowLinking { return nil, ErrorUserConflict("third party account linking for existing user with same email disallowed") }