Skip to content

Commit

Permalink
fix sso on invite
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Jan 25, 2025
1 parent fa42066 commit ddff7ae
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions agixt/MagicalAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,6 +2338,7 @@ def sso(
provider = str(provider).lower()
if provider not in ["amazon", "microsoft", "google", "github", "walmart"]:
provider = "microsoft"

sso_data = get_sso_provider(provider=provider, code=code, redirect_uri=referrer)
if not sso_data:
logging.error(f"Failed to get user data from {provider.capitalize()}.")
Expand All @@ -2352,6 +2353,7 @@ def sso(
status_code=400,
detail=f"Failed to get access token from {provider.capitalize()}.",
)

user_data = sso_data.user_info
access_token = sso_data.access_token
refresh_token = sso_data.refresh_token
Expand Down Expand Up @@ -2387,6 +2389,7 @@ def sso(
or user_data.get("mail")
or user_data.get("login")
)

if not account_name:
logging.error(
f"Could not get account identifier from {provider.capitalize()} response."
Expand All @@ -2395,6 +2398,7 @@ def sso(
status_code=400,
detail=f"Could not get account identifier from {provider.capitalize()} response.",
)

session = get_session()
try:
provider_record = (
Expand All @@ -2406,34 +2410,44 @@ def sso(
provider_record = OAuthProvider(name=provider)
session.add(provider_record)
session.commit()

# Initialize mfa_token as None
mfa_token = None

# Check for existing OAuth connection
existing_oauth = (
session.query(UserOAuth)
.filter(UserOAuth.provider_id == provider_record.id)
.filter(UserOAuth.account_name == account_name)
.first()
)

if existing_oauth:
if self.user_id:
# If user is already logged in and trying to connect a provider that's
# connected to another account, prevent it
if str(existing_oauth.user_id) != str(self.user_id):
raise HTTPException(
status_code=400,
detail=f"This {provider} account is already connected to a different user.",
)
else:
# Only set user_id if no user is currently logged in
self.user_id = existing_oauth.user_id
user = session.query(User).filter(User.id == self.user_id).first()
self.email = user.email
mfa_token = user.mfa_token
# Get the user associated with this OAuth connection
user = (
session.query(User)
.filter(User.id == existing_oauth.user_id)
.first()
)
if not user:
raise HTTPException(status_code=404, detail="User not found")

self.user_id = str(user.id)
self.email = user.email
mfa_token = user.mfa_token

if self.user_id and str(existing_oauth.user_id) != str(self.user_id):
raise HTTPException(
status_code=400,
detail=f"This {provider} account is already connected to a different user.",
)
else:
# No existing OAuth connection found
if not self.user_id:
# If no user is logged in, look up or create user by email
email = user_data.get("email") or account_name
user = session.query(User).filter(User.email == email).first()

if user:
self.user_id = str(user.id)
self.email = user.email
Expand All @@ -2450,6 +2464,7 @@ def sso(
invitation_id=invitation_id,
verify_email=True,
)

if isinstance(registration_response, dict):
if "error" in registration_response:
raise HTTPException(
Expand All @@ -2464,7 +2479,17 @@ def sso(
else:
# User is logged in, get their MFA token
user = session.query(User).filter(User.id == self.user_id).first()
if not user:
raise HTTPException(status_code=404, detail="User not found")
mfa_token = user.mfa_token

# Verify we have an MFA token before proceeding
if not mfa_token:
raise HTTPException(
status_code=500,
detail="Failed to get or generate MFA token",
)

# Update or create OAuth connection
self.update_sso(
account_name=account_name,
Expand All @@ -2473,6 +2498,7 @@ def sso(
token_expires_at=token_expires_at,
refresh_token=refresh_token,
)

totp = pyotp.TOTP(mfa_token)
login = Login(email=self.email, token=totp.now())
return self.send_magic_link(
Expand Down

0 comments on commit ddff7ae

Please sign in to comment.