diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py index 5ab3230bc7c3..8dad2ba811ef 100644 --- a/agixt/MagicalAuth.py +++ b/agixt/MagicalAuth.py @@ -2338,6 +2338,7 @@ def sso( provider = str(provider).lower() if provider not in ["amazon", "microsoft", "google", "github", "walmart"]: provider = "microsoft" + sso_data = get_sso_provider(provider=provider, code=code, redirect_uri=referrer) if not sso_data: logging.error(f"Failed to get user data from {provider.capitalize()}.") @@ -2352,6 +2353,7 @@ def sso( status_code=400, detail=f"Failed to get access token from {provider.capitalize()}.", ) + user_data = sso_data.user_info access_token = sso_data.access_token refresh_token = sso_data.refresh_token @@ -2387,6 +2389,7 @@ def sso( or user_data.get("mail") or user_data.get("login") ) + if not account_name: logging.error( f"Could not get account identifier from {provider.capitalize()} response." @@ -2395,6 +2398,7 @@ def sso( status_code=400, detail=f"Could not get account identifier from {provider.capitalize()} response.", ) + session = get_session() try: provider_record = ( @@ -2406,6 +2410,10 @@ def sso( provider_record = OAuthProvider(name=provider) session.add(provider_record) session.commit() + + # Initialize mfa_token as None + mfa_token = None + # Check for existing OAuth connection existing_oauth = ( session.query(UserOAuth) @@ -2413,27 +2421,33 @@ def sso( .filter(UserOAuth.account_name == account_name) .first() ) + if existing_oauth: - if self.user_id: - # If user is already logged in and trying to connect a provider that's - # connected to another account, prevent it - if str(existing_oauth.user_id) != str(self.user_id): - raise HTTPException( - status_code=400, - detail=f"This {provider} account is already connected to a different user.", - ) - else: - # Only set user_id if no user is currently logged in - self.user_id = existing_oauth.user_id - user = session.query(User).filter(User.id == self.user_id).first() - self.email = user.email - mfa_token = user.mfa_token + # Get the user associated with this OAuth connection + user = ( + session.query(User) + .filter(User.id == existing_oauth.user_id) + .first() + ) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + self.user_id = str(user.id) + self.email = user.email + mfa_token = user.mfa_token + + if self.user_id and str(existing_oauth.user_id) != str(self.user_id): + raise HTTPException( + status_code=400, + detail=f"This {provider} account is already connected to a different user.", + ) else: # No existing OAuth connection found if not self.user_id: # If no user is logged in, look up or create user by email email = user_data.get("email") or account_name user = session.query(User).filter(User.email == email).first() + if user: self.user_id = str(user.id) self.email = user.email @@ -2450,6 +2464,7 @@ def sso( invitation_id=invitation_id, verify_email=True, ) + if isinstance(registration_response, dict): if "error" in registration_response: raise HTTPException( @@ -2464,7 +2479,17 @@ def sso( else: # User is logged in, get their MFA token user = session.query(User).filter(User.id == self.user_id).first() + if not user: + raise HTTPException(status_code=404, detail="User not found") mfa_token = user.mfa_token + + # Verify we have an MFA token before proceeding + if not mfa_token: + raise HTTPException( + status_code=500, + detail="Failed to get or generate MFA token", + ) + # Update or create OAuth connection self.update_sso( account_name=account_name, @@ -2473,6 +2498,7 @@ def sso( token_expires_at=token_expires_at, refresh_token=refresh_token, ) + totp = pyotp.TOTP(mfa_token) login = Login(email=self.email, token=totp.now()) return self.send_magic_link(