Skip to content

Commit

Permalink
Merge pull request #421 from madeofpendletonwool/oidc
Browse files Browse the repository at this point in the history
Fully added OIDC Logins
  • Loading branch information
madeofpendletonwool authored Feb 10, 2025
2 parents efd1c0b + ad0beec commit 707be4e
Show file tree
Hide file tree
Showing 16 changed files with 1,446 additions and 56 deletions.
167 changes: 164 additions & 3 deletions clients/clientapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, status, Header, Body, Path, Form, Query, \
security, BackgroundTasks, UploadFile
from fastapi.security import APIKeyHeader, HTTPBasic, HTTPBasicCredentials
from fastapi.responses import PlainTextResponse, JSONResponse, Response, FileResponse, StreamingResponse
from fastapi.responses import PlainTextResponse, JSONResponse, Response, FileResponse, StreamingResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.concurrency import run_in_threadpool
from threading import Lock
Expand Down Expand Up @@ -75,6 +75,7 @@ def sigterm_handler(_signo, _stack_frame):
import database_functions.auth_functions
import database_functions.app_functions
import database_functions.import_progress
import database_functions.oidc_state_manager
import database_functions.valkey_client
import database_functions.youtube

Expand Down Expand Up @@ -3251,10 +3252,10 @@ class OIDCProviderValues(BaseModel):
authorization_url: str
token_url: str
user_info_url: str
redirect_url: str
button_text: str
scope: Optional[str] = "openid email profile"
button_color: Optional[str] = "#000000"
button_text_color: Optional[str] = "#000000"
icon_svg: Optional[str] = None

@app.post("/api/data/add_oidc_provider")
Expand All @@ -3271,10 +3272,10 @@ async def api_add_oidc_provider(
provider_values.authorization_url,
provider_values.token_url,
provider_values.user_info_url,
provider_values.redirect_url,
provider_values.button_text,
provider_values.scope,
provider_values.button_color,
provider_values.button_text_color,
provider_values.icon_svg
))
if not provider_id:
Expand Down Expand Up @@ -3330,6 +3331,20 @@ async def api_list_oidc_providers(
detail=f"An unexpected error occurred while listing providers: {str(e)}"
)

# Public reqeust for login info
@app.get("/api/data/public_oidc_providers")
async def api_public_oidc_providers(cnx=Depends(get_database_connection)):
"""Get minimal OIDC provider info needed for login screen buttons."""
try:
providers = database_functions.functions.get_public_oidc_providers(cnx, database_type)
return {"providers": providers}
except Exception as e:
logging.error(f"Unexpected error getting public provider info: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"An unexpected error occurred: {str(e)}"
)


@app.put("/api/data/user/set_theme")
async def api_set_theme(user_id: int = Body(...), new_theme: str = Body(...), cnx=Depends(get_database_connection),
Expand Down Expand Up @@ -5646,6 +5661,152 @@ async def subscribe_to_youtube_channel(
detail=f"Error subscribing to channel: {str(e)}"
)

@app.post("/api/auth/store_state")
async def store_oidc_state(
request: Request,
):
try:
data = await request.json()
state = data.get('state')
client_id = data.get('client_id')

if not state or not client_id:
raise HTTPException(status_code=400, detail="Missing state or client_id")

success = database_functions.oidc_state_manager.oidc_state_manager.store_state(state, client_id)
if not success:
raise HTTPException(status_code=500, detail="Failed to store state")

return {"status": "success"}
except Exception as e:
logging.error(f"Error storing OIDC state: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to store state")

@app.get("/api/auth/callback")
async def oidc_callback(
request: Request,
code: str,
state: str = None,
cnx=Depends(get_database_connection)
):
try:
base_url = str(request.base_url)[:-1]
frontend_base = base_url.replace('/api', '') # Remove /api for frontend URLs

# Get client_id from query parameters
client_id = database_functions.oidc_state_manager.oidc_state_manager.get_client_id(state)
if not client_id:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=invalid_state"
)

# Get OIDC provider details
provider = database_functions.functions.get_oidc_provider(cnx, database_type, client_id)
if not provider:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=invalid_provider"
)

# Unpack provider details
provider_id, client_id, client_secret, token_url, userinfo_url = provider

# Exchange authorization code for access token
async with httpx.AsyncClient() as client:
try:
token_response = await client.post(
token_url,
data={
"grant_type": "authorization_code",
"code": code,
"redirect_uri": f"{base_url}/api/auth/callback",
"client_id": client_id,
"client_secret": client_secret,
}
)

if token_response.status_code != 200:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=token_exchange_failed"
)

token_data = token_response.json()
access_token = token_data.get("access_token")

# Get user info from OIDC provider
headers = {"Authorization": f"Bearer {access_token}"}
userinfo_response = await client.get(userinfo_url, headers=headers)

if userinfo_response.status_code != 200:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=userinfo_failed"
)

user_info = userinfo_response.json()
email = user_info.get("email")

if not email:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=email_required"
)

except httpx.RequestError:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=network_error"
)

# Check if user exists
user = database_functions.functions.get_user_by_email(cnx, database_type, email)

if not user:
# Create new user
fullname = user_info.get("name", "")
username = email.split("@")[0].lower()
base_username = username
counter = 1
max_attempts = 10

while counter <= max_attempts:
try:
user_id = database_functions.functions.create_oidc_user(
cnx, database_type, email, fullname, username
)
api_key = database_functions.functions.create_api_key(cnx, database_type, user_id)
break
except UniqueViolation:
username = f"{base_username}{counter}"
counter += 1
if counter > max_attempts:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=username_conflict"
)
else:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=user_creation_failed"
)
else:
# Handle existing user
user_id = user[0] if isinstance(user, tuple) else user['userid']
existing_api_key = database_functions.functions.get_user_api_key(cnx, database_type, user_id)

if existing_api_key:
api_key = existing_api_key
else:
try:
api_key = database_functions.functions.create_api_key(cnx, database_type, user_id)
except Exception:
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=api_key_creation_failed"
)

# Success case - redirect with API key
return RedirectResponse(url=f"{frontend_base}/oauth/callback?api_key={api_key}")

except Exception as e:
logging.error(f"OIDC callback error: {str(e)}")
return RedirectResponse(
url=f"{frontend_base}/oauth/callback?error=authentication_failed"
)

class InitRequest(BaseModel):
api_key: str

Expand Down
Loading

0 comments on commit 707be4e

Please sign in to comment.