Skip to content

Commit

Permalink
Merge branch 'oidc' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
madeofpendletonwool authored Feb 8, 2025
2 parents 5f41489 + b7e7ee7 commit efd1c0b
Show file tree
Hide file tree
Showing 12 changed files with 1,075 additions and 18 deletions.
87 changes: 87 additions & 0 deletions clients/clientapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3244,6 +3244,93 @@ async def api_delete_user(is_admin: bool = Depends(check_if_admin), cnx=Depends(
return {"status": "User deleted"}


class OIDCProviderValues(BaseModel):
provider_name: str
client_id: str
client_secret: str
authorization_url: str
token_url: str
user_info_url: str
redirect_url: str
button_text: str
scope: Optional[str] = "openid email profile"
button_color: Optional[str] = "#000000"
icon_svg: Optional[str] = None

@app.post("/api/data/add_oidc_provider")
async def api_add_oidc_provider(
is_admin: bool = Depends(check_if_admin),
cnx=Depends(get_database_connection),
api_key: str = Depends(get_api_key_from_header),
provider_values: OIDCProviderValues = Body(...)):
try:
provider_id = database_functions.functions.add_oidc_provider(cnx, database_type, (
provider_values.provider_name,
provider_values.client_id,
provider_values.client_secret,
provider_values.authorization_url,
provider_values.token_url,
provider_values.user_info_url,
provider_values.redirect_url,
provider_values.button_text,
provider_values.scope,
provider_values.button_color,
provider_values.icon_svg
))
if not provider_id:
raise HTTPException(
status_code=500,
detail="Failed to create provider - no provider ID returned"
)
return {"detail": "Success", "provider_id": provider_id}
except psycopg.errors.UniqueViolation:
raise HTTPException(
status_code=409,
detail="A provider with this name already exists"
)
except Exception as e:
logging.error(f"Unexpected error adding provider: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"An unexpected error occurred while creating the provider: {str(e)}"
)

@app.post("/api/data/remove_oidc_provider")
async def api_remove_oidc_provider(
is_admin: bool = Depends(check_if_admin),
cnx=Depends(get_database_connection),
api_key: str = Depends(get_api_key_from_header),
provider_id: int = Body(...)):
try:
result = database_functions.functions.remove_oidc_provider(cnx, database_type, provider_id)
if not result:
raise HTTPException(
status_code=404,
detail="Provider not found"
)
return {"detail": "Success"}
except Exception as e:
logging.error(f"Unexpected error removing provider: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"An unexpected error occurred while removing the provider: {str(e)}"
)

@app.get("/api/data/list_oidc_providers")
async def api_list_oidc_providers(
cnx=Depends(get_database_connection),
api_key: str = Depends(get_api_key_from_header)):
try:
providers = database_functions.functions.list_oidc_providers(cnx, database_type)
return {"providers": providers}
except Exception as e:
logging.error(f"Unexpected error listing providers: {str(e)}")
raise HTTPException(
status_code=500,
detail=f"An unexpected error occurred while listing providers: {str(e)}"
)


@app.put("/api/data/user/set_theme")
async def api_set_theme(user_id: int = Body(...), new_theme: str = Body(...), cnx=Depends(get_database_connection),
api_key: str = Depends(get_api_key_from_header)):
Expand Down
119 changes: 119 additions & 0 deletions database_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,125 @@ def add_admin_user(cnx, database_type, user_values):
finally:
cursor.close()

def add_oidc_provider(cnx, database_type, provider_values):
cursor = cnx.cursor()
try:
if database_type == "postgresql":
add_provider_query = """
INSERT INTO "OIDCProviders"
(ProviderName, ClientID, ClientSecret, AuthorizationURL,
TokenURL, UserInfoURL, RedirectURL, ButtonText,
Scope, ButtonColor, IconSVG)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING ProviderID
"""
else: # MySQL
add_provider_query = """
INSERT INTO OIDCProviders
(ProviderName, ClientID, ClientSecret, AuthorizationURL,
TokenURL, UserInfoURL, RedirectURL, ButtonText,
Scope, ButtonColor, IconSVG)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(add_provider_query, provider_values)

if database_type == "postgresql":
result = cursor.fetchone()
if isinstance(result, dict):
provider_id = result.get('providerid') or result.get('ProviderID') or result.get('provider_id')
else:
provider_id = result[0]
else:
provider_id = cursor.lastrowid

cnx.commit()
return provider_id
except Exception as e:
cnx.rollback()
logging.error(f"Error in add_oidc_provider: {str(e)}")
raise
finally:
cursor.close()

def remove_oidc_provider(cnx, database_type, provider_id):
cursor = cnx.cursor()
try:
if database_type == "postgresql":
delete_query = """
DELETE FROM "OIDCProviders"
WHERE ProviderID = %s
"""
else:
delete_query = """
DELETE FROM OIDCProviders
WHERE ProviderID = %s
"""
cursor.execute(delete_query, (provider_id,))
rows_affected = cursor.rowcount
cnx.commit()
return rows_affected > 0
except Exception as e:
cnx.rollback()
logging.error(f"Error in remove_oidc_provider: {str(e)}")
raise
finally:
cursor.close()

def list_oidc_providers(cnx, database_type):
cursor = cnx.cursor()
try:
if database_type == "postgresql":
list_query = """
SELECT ProviderID, ProviderName, ClientID, AuthorizationURL,
TokenURL, UserInfoURL, RedirectURL, ButtonText,
Scope, ButtonColor, IconSVG, Enabled, Created, Modified
FROM "OIDCProviders"
ORDER BY ProviderName
"""
else:
list_query = """
SELECT ProviderID, ProviderName, ClientID, AuthorizationURL,
TokenURL, UserInfoURL, RedirectURL, ButtonText,
Scope, ButtonColor, IconSVG, Enabled, Created, Modified
FROM OIDCProviders
ORDER BY ProviderName
"""
cursor.execute(list_query)

if database_type == "postgresql":
results = cursor.fetchall()
providers = []
for row in results:
if isinstance(row, dict):
providers.append(row)
else:
providers.append({
'provider_id': row[0],
'provider_name': row[1],
'client_id': row[2],
'authorization_url': row[3],
'token_url': row[4],
'user_info_url': row[5],
'redirect_url': row[6],
'button_text': row[7],
'scope': row[8],
'button_color': row[9],
'icon_svg': row[10],
'enabled': row[11],
'created': row[12],
'modified': row[13]
})
else:
columns = [col[0] for col in cursor.description]
providers = [dict(zip(columns, row)) for row in cursor.fetchall()]

return providers
except Exception as e:
logging.error(f"Error in list_oidc_providers: {str(e)}")
raise
finally:
cursor.close()

def get_pinepods_version():
try:
with open('/pinepods/current_version', 'r') as file:
Expand Down
65 changes: 64 additions & 1 deletion startup/setupdatabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,25 @@ def ensure_usernames_lowercase(cnx):
cnx.commit()
cursor.close()

# Function to check and add columns if they don't exist
def add_column_if_not_exists(cursor, table_name, column_name, column_definition):
cursor.execute(f"""
SELECT COUNT(*)
FROM information_schema.columns
WHERE table_name='{table_name}'
AND column_name='{column_name}'
AND table_schema=DATABASE();
""")
if cursor.fetchone()[0] == 0:
cursor.execute(f"""
ALTER TABLE {table_name}
ADD COLUMN {column_name} {column_definition};
""")
print(f"Column '{column_name}' added to table '{table_name}'")
else:
return

# Execute SQL command to create tables
# Create Users table if it doesn't exist (your existing code)
cursor.execute("""
CREATE TABLE IF NOT EXISTS Users (
UserID INT AUTO_INCREMENT PRIMARY KEY,
Expand All @@ -82,10 +99,56 @@ def ensure_usernames_lowercase(cnx):
GpodderLoginName VARCHAR(255) DEFAULT '',
GpodderToken VARCHAR(255) DEFAULT '',
EnableRSSFeeds TINYINT(1) DEFAULT 0,
auth_type VARCHAR(50) DEFAULT 'standard',
oidc_provider_id INT,
oidc_subject VARCHAR(255),
UNIQUE (Username)
)
""")

# Create OIDCProviders table if it doesn't exist
cursor.execute("""
CREATE TABLE IF NOT EXISTS OIDCProviders (
ProviderID INT AUTO_INCREMENT PRIMARY KEY,
ProviderName VARCHAR(255) NOT NULL,
ClientID VARCHAR(255) NOT NULL,
ClientSecret VARCHAR(500) NOT NULL,
AuthorizationURL VARCHAR(255) NOT NULL,
TokenURL VARCHAR(255) NOT NULL,
UserInfoURL VARCHAR(255) NOT NULL,
RedirectURL VARCHAR(255) NOT NULL,
Scope VARCHAR(255) DEFAULT 'openid email profile',
ButtonColor VARCHAR(50) DEFAULT '#000000',
ButtonText VARCHAR(255) NOT NULL,
IconSVG TEXT,
Enabled TINYINT(1) DEFAULT 1,
Created DATETIME DEFAULT CURRENT_TIMESTAMP,
Modified DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
)
""")

# Add new columns to Users table if they don't exist
add_column_if_not_exists(cursor, 'Users', 'auth_type', 'VARCHAR(50) DEFAULT \'standard\'')
add_column_if_not_exists(cursor, 'Users', 'oidc_provider_id', 'INT')
add_column_if_not_exists(cursor, 'Users', 'oidc_subject', 'VARCHAR(255)')

# Check if foreign key exists before adding it
cursor.execute("""
SELECT COUNT(*)
FROM information_schema.table_constraints
WHERE constraint_name = 'fk_oidc_provider'
AND table_name = 'Users'
AND table_schema = DATABASE();
""")
if cursor.fetchone()[0] == 0:
cursor.execute("""
ALTER TABLE Users
ADD CONSTRAINT fk_oidc_provider
FOREIGN KEY (oidc_provider_id)
REFERENCES OIDCProviders(ProviderID);
""")
print("Foreign key constraint 'fk_oidc_provider' added")

# Add EnableRSSFeeds column if it doesn't exist
cursor.execute("""
SELECT COUNT(*)
Expand Down
Loading

0 comments on commit efd1c0b

Please sign in to comment.