diff --git a/litellm/proxy/management_endpoints/scim/scim_v2.py b/litellm/proxy/management_endpoints/scim/scim_v2.py index fd7d150a7f4f..b8f6b4a44602 100644 --- a/litellm/proxy/management_endpoints/scim/scim_v2.py +++ b/litellm/proxy/management_endpoints/scim/scim_v2.py @@ -4,7 +4,6 @@ This is an enterprise feature and requires a premium license. """ -from litellm._uuid import uuid from typing import Any, Dict, List, Optional, Set, Tuple from fastapi import ( @@ -17,11 +16,12 @@ Request, Response, ) -from typing_extensions import TypedDict from pydantic import BaseModel +from typing_extensions import TypedDict import litellm from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ( LiteLLM_TeamTable, @@ -51,32 +51,31 @@ class UserProvisionerHelpers: """Helper methods for user provisioning operations.""" - + @staticmethod async def handle_existing_user_by_email( - prisma_client, - new_user_request: NewUserRequest + prisma_client, new_user_request: NewUserRequest ) -> Optional[SCIMUser]: """ Check if a user with the given email already exists and update them if found. - + Args: prisma_client: Database client new_user_request: New user request data - + Returns: SCIMUser if user was updated, None if no existing user found """ if not new_user_request.user_email: return None - + existing_user = await prisma_client.db.litellm_usertable.find_first( where={"user_email": new_user_request.user_email} ) - + if not existing_user: return None - + # Update the user updated_user = await prisma_client.db.litellm_usertable.update( where={"user_id": existing_user.user_id}, @@ -88,12 +87,15 @@ async def handle_existing_user_by_email( "metadata": safe_dumps(new_user_request.metadata), }, ) - - return await ScimTransformations.transform_litellm_user_to_scim_user(updated_user) + + return await ScimTransformations.transform_litellm_user_to_scim_user( + updated_user + ) class ScimUserData(TypedDict): """Typed structure for extracted SCIM user data.""" + user_email: Optional[str] user_alias: Optional[str] sso_user_id: Optional[str] @@ -105,6 +107,7 @@ class ScimUserData(TypedDict): class GroupMemberExtractionResult(BaseModel): """Result of extracting and processing group members.""" + existing_member_ids: List[str] created_users: List[NewUserResponse] all_member_ids: List[str] # existing + newly created @@ -121,7 +124,7 @@ class GroupMemberExtractionResult(BaseModel): async def _get_prisma_client_or_raise_exception(): """Check if database is connected and raise HTTPException if not.""" from litellm.proxy.proxy_server import prisma_client - + if prisma_client is None: raise HTTPException(status_code=500, detail={"error": "No database connected"}) return prisma_client @@ -130,32 +133,32 @@ async def _get_prisma_client_or_raise_exception(): async def _check_user_exists(user_id: str): """Check if user exists and return user, raise 404 if not found.""" prisma_client = await _get_prisma_client_or_raise_exception() - + user = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id} ) - + if not user: raise HTTPException( status_code=404, detail={"error": f"User not found with ID: {user_id}"} ) - + return user async def _check_team_exists(team_id: str): """Check if team exists and return team, raise 404 if not found.""" prisma_client = await _get_prisma_client_or_raise_exception() - + team = await prisma_client.db.litellm_teamtable.find_unique( where={"team_id": team_id} ) - + if not team: raise HTTPException( status_code=404, detail={"error": f"Group not found with ID: {team_id}"} ) - + return team @@ -184,7 +187,9 @@ def _extract_scim_user_data(user: SCIMUser) -> ScimUserData: } -def _build_scim_metadata(given_name: Optional[str], family_name: Optional[str], active: Optional[bool] = None) -> Dict[str, Any]: +def _build_scim_metadata( + given_name: Optional[str], family_name: Optional[str], active: Optional[bool] = None +) -> Dict[str, Any]: """Build metadata dictionary with SCIM data.""" metadata: Dict[str, Any] = { "scim_metadata": LiteLLM_UserScimMetadata( @@ -192,17 +197,17 @@ def _build_scim_metadata(given_name: Optional[str], family_name: Optional[str], familyName=family_name, ).model_dump() } - + if active is not None: metadata["scim_active"] = active - + return metadata async def _extract_group_member_ids(group: SCIMGroup) -> GroupMemberExtractionResult: """ Extract member IDs from SCIMGroup, creating users that don't exist. - + Returns: GroupMemberExtractionResult with existing members, created users, and all member IDs """ @@ -210,35 +215,34 @@ async def _extract_group_member_ids(group: SCIMGroup) -> GroupMemberExtractionRe existing_member_ids = [] created_users = [] all_member_ids = [] - + if group.members: for member in group.members: user_id = member.value - + # Check if user exists user = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": user_id} ) - + if user: existing_member_ids.append(user_id) all_member_ids.append(user_id) else: # Create the user if they don't exist using our helper created_user = await _create_user_if_not_exists( - user_id=user_id, - created_via="scim_group_membership" + user_id=user_id, created_via="scim_group_membership" ) - + if created_user: created_users.append(created_user) all_member_ids.append(user_id) # If creation failed, user is skipped (logged in helper) - + return GroupMemberExtractionResult( existing_member_ids=existing_member_ids, created_users=created_users, - all_member_ids=all_member_ids + all_member_ids=all_member_ids, ) @@ -246,7 +250,7 @@ async def _get_team_members_display(member_ids: List[str]) -> List[SCIMMember]: """Get SCIMMember objects with display names for a list of member IDs.""" prisma_client = await _get_prisma_client_or_raise_exception() members: List[SCIMMember] = [] - + for member_id in member_ids: user = await prisma_client.db.litellm_usertable.find_unique( where={"user_id": member_id} @@ -254,18 +258,20 @@ async def _get_team_members_display(member_ids: List[str]) -> List[SCIMMember]: if user: display_name = user.user_email or user.user_id members.append(SCIMMember(value=user.user_id, display=display_name)) - + return members -async def _handle_team_membership_changes(user_id: str, existing_teams: List[str], new_teams: List[str]) -> None: +async def _handle_team_membership_changes( + user_id: str, existing_teams: List[str], new_teams: List[str] +) -> None: """Handle adding/removing user from teams based on changes.""" existing_teams_set = set(existing_teams) new_teams_set = set(new_teams) - + teams_to_add = new_teams_set - existing_teams_set teams_to_remove = existing_teams_set - new_teams_set - + if teams_to_add or teams_to_remove: await patch_team_membership( user_id=user_id, @@ -274,19 +280,21 @@ async def _handle_team_membership_changes(user_id: str, existing_teams: List[str ) -async def _create_user_if_not_exists(user_id: str, created_via: str = "scim_group") -> Optional[NewUserResponse]: +async def _create_user_if_not_exists( + user_id: str, created_via: str = "scim_group" +) -> Optional[NewUserResponse]: """ Helper function to create a user if they don't exist. - + Args: user_id: The user ID to create created_via: Context for where the user was created from - + Returns: LiteLLM_UserTable if user was created, None if creation failed """ from litellm.proxy.management_endpoints.internal_user_endpoints import new_user - + try: # Get default role for new internal users default_role: Optional[ @@ -313,7 +321,7 @@ async def _create_user_if_not_exists(user_id: str, created_via: str = "scim_grou created_user = await new_user(data=new_user_request) verbose_proxy_logger.info(f"Created user {user_id} via {created_via}") return created_user - + except Exception as e: verbose_proxy_logger.exception(f"Failed to create user {user_id}: {e}") return None @@ -324,7 +332,7 @@ async def _get_team_member_user_ids_from_team(team: LiteLLM_TeamTable) -> List[s Get the IDs of the members from a team. Use one source of truth for the member IDs: team.members_with_roles - + """ member_user_ids: List[str] = [] for member in team.members_with_roles or []: @@ -337,7 +345,6 @@ async def _get_team_member_user_ids_from_team(team: LiteLLM_TeamTable) -> List[s return member_user_ids - # Dependency to set the correct SCIM Content-Type async def set_scim_content_type(response: Response): """Sets the Content-Type header to application/scim+json""" @@ -450,7 +457,7 @@ async def get_user( verbose_proxy_logger.debug("SCIM GET USER request for user_id=%s", user_id) try: user = await _check_user_exists(user_id) - + # Convert to SCIM format scim_user = await ScimTransformations.transform_litellm_user_to_scim_user(user) return scim_user @@ -458,6 +465,7 @@ async def get_user( except Exception as e: raise handle_exception_on_proxy(e) + @scim_router.post( "/Users", response_model=SCIMUser, @@ -471,11 +479,9 @@ async def create_user( Create a user according to SCIM v2 protocol """ try: - verbose_proxy_logger.debug( - "SCIM CREATE USER request: %s", user.model_dump() - ) + verbose_proxy_logger.debug("SCIM CREATE USER request: %s", user.model_dump()) prisma_client = await _get_prisma_client_or_raise_exception() - + # Extract data from SCIM user user_data = _extract_scim_user_data(user) @@ -487,20 +493,24 @@ async def create_user( if existing_user: raise HTTPException( status_code=409, - detail={"error": f"User already exists with username: {user.userName}"}, + detail={ + "error": f"User already exists with username: {user.userName}" + }, ) # Create user in database user_id = user.userName or str(uuid.uuid4()) - metadata = _build_scim_metadata(user_data["given_name"], user_data["family_name"]) + metadata = _build_scim_metadata( + user_data["given_name"], user_data["family_name"] + ) default_role: Optional[ Literal[ - LitellmUserRoles.PROXY_ADMIN, - LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, - LitellmUserRoles.INTERNAL_USER, - LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, - ] + LitellmUserRoles.PROXY_ADMIN, + LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, + LitellmUserRoles.INTERNAL_USER, + LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, + ] ] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY if litellm.default_internal_user_params: default_role = litellm.default_internal_user_params.get("user_role") @@ -517,22 +527,23 @@ async def create_user( # Check if user with email already exists and update if found existing_user_scim = await UserProvisionerHelpers.handle_existing_user_by_email( - prisma_client=prisma_client, - new_user_request=new_user_request + prisma_client=prisma_client, new_user_request=new_user_request ) - + if existing_user_scim: return existing_user_scim created_user = await new_user( data=new_user_request, ) - + scim_user = await ScimTransformations.transform_litellm_user_to_scim_user( user=created_user ) return scim_user - except HTTPException as e: # allow exceptions like SCIMUserAlreadyExists to be raised + except ( + HTTPException + ) as e: # allow exceptions like SCIMUserAlreadyExists to be raised raise e except Exception as e: raise handle_exception_on_proxy(e) @@ -564,18 +575,16 @@ async def update_user( # Extract data from SCIM user user_data = _extract_scim_user_data(user) - # Build metadata with SCIM data + # Build metadata with SCIM data metadata = _build_scim_metadata( - user_data["given_name"], - user_data["family_name"], - user_data["active"] + user_data["given_name"], user_data["family_name"], user_data["active"] ) # Handle team membership changes await _handle_team_membership_changes( user_id=user_id, existing_teams=existing_user.teams or [], - new_teams=user_data["teams"] + new_teams=user_data["teams"], ) # Update user with all new data (full replacement) @@ -590,6 +599,7 @@ async def update_user( # Serialize metadata to JSON string for Prisma to avoid GraphQL parsing issues if "metadata" in update_data and isinstance(update_data["metadata"], dict): from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + update_data["metadata"] = safe_dumps(update_data["metadata"]) updated_user = await prisma_client.db.litellm_usertable.update( @@ -598,8 +608,10 @@ async def update_user( ) # Convert back to SCIM format - scim_user = await ScimTransformations.transform_litellm_user_to_scim_user(updated_user) - + scim_user = await ScimTransformations.transform_litellm_user_to_scim_user( + updated_user + ) + return scim_user except Exception as e: @@ -617,9 +629,7 @@ async def delete_user( """ Delete a user according to SCIM v2 protocol """ - verbose_proxy_logger.debug( - "SCIM DELETE USER request for user_id=%s", user_id - ) + verbose_proxy_logger.debug("SCIM DELETE USER request for user_id=%s", user_id) try: prisma_client = await _get_prisma_client_or_raise_exception() existing_user = await _check_user_exists(user_id) @@ -668,7 +678,9 @@ def _extract_group_values(value: Any) -> List[str]: return group_values -def _handle_displayname_update(op_type: str, value: Any, update_data: Dict[str, Any]) -> None: +def _handle_displayname_update( + op_type: str, value: Any, update_data: Dict[str, Any] +) -> None: """Handle displayname updates.""" if op_type == "remove": update_data["user_alias"] = None @@ -676,7 +688,9 @@ def _handle_displayname_update(op_type: str, value: Any, update_data: Dict[str, update_data["user_alias"] = str(value) -def _handle_externalid_update(op_type: str, value: Any, update_data: Dict[str, Any]) -> None: +def _handle_externalid_update( + op_type: str, value: Any, update_data: Dict[str, Any] +) -> None: """Handle externalid updates.""" if op_type == "remove": update_data["sso_user_id"] = None @@ -697,7 +711,9 @@ def _handle_active_update(op_type: str, value: Any, metadata: Dict[str, Any]) -> metadata["scim_active"] = bool_val -def _handle_name_update(path: str, op_type: str, value: Any, scim_metadata: Dict[str, Any]) -> None: +def _handle_name_update( + path: str, op_type: str, value: Any, scim_metadata: Dict[str, Any] +) -> None: """Handle name field updates (givenName, familyName).""" if path == "name.givenname": if op_type == "remove": @@ -711,7 +727,9 @@ def _handle_name_update(path: str, op_type: str, value: Any, scim_metadata: Dict scim_metadata["familyName"] = str(value) -def _handle_group_operations(op_type: str, value: Any, teams_set: Set[str]) -> Optional[Set[str]]: +def _handle_group_operations( + op_type: str, value: Any, teams_set: Set[str] +) -> Optional[Set[str]]: """Handle group/team membership operations.""" group_values = _extract_group_values(value) if op_type == "replace": @@ -724,7 +742,9 @@ def _handle_group_operations(op_type: str, value: Any, teams_set: Set[str]) -> O return None -def _handle_generic_metadata(path: str, op_type: str, value: Any, metadata: Dict[str, Any]) -> None: +def _handle_generic_metadata( + path: str, op_type: str, value: Any, metadata: Dict[str, Any] +) -> None: """Handle generic metadata operations for unknown paths.""" if op_type == "remove": metadata.pop(path, None) @@ -769,6 +789,7 @@ def _apply_patch_ops( update_data["metadata"] = metadata return update_data, final_team_set + async def patch_team_membership( user_id: str, teams_ids_to_add_user_to: List[str], @@ -778,29 +799,35 @@ async def patch_team_membership( Add or remove user from teams """ for _team_id in teams_ids_to_add_user_to: - try: - await team_member_add( - data=TeamMemberAddRequest( - team_id=_team_id, - member=Member(user_id=user_id, role="user"), - ), - user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), - ) - except Exception as e: - verbose_proxy_logger.exception(f"Error adding user to team {_team_id}: {e}") + try: + await team_member_add( + data=TeamMemberAddRequest( + team_id=_team_id, + member=Member(user_id=user_id, role="user"), + ), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN + ), + ) + except Exception as e: + verbose_proxy_logger.exception(f"Error adding user to team {_team_id}: {e}") for _team_id in teams_ids_to_remove_user_from: try: await team_member_delete( data=TeamMemberDeleteRequest(team_id=_team_id, user_id=user_id), - user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN + ), ) except Exception as e: - verbose_proxy_logger.exception(f"Error removing user from team {_team_id}: {e}") - + verbose_proxy_logger.exception( + f"Error removing user from team {_team_id}: {e}" + ) return True + @scim_router.patch( "/Users/{user_id}", response_model=SCIMUser, @@ -833,7 +860,7 @@ async def patch_user( await _handle_team_membership_changes( user_id=user_id, existing_teams=existing_user.teams or [], - new_teams=list(final_team_set) + new_teams=list(final_team_set), ) update_data["teams"] = list(final_team_set) @@ -841,6 +868,7 @@ async def patch_user( # Serialize metadata to JSON string for Prisma to avoid GraphQL parsing issues if "metadata" in update_data and isinstance(update_data["metadata"], dict): from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + update_data["metadata"] = safe_dumps(update_data["metadata"]) updated_user = await prisma_client.db.litellm_usertable.update( @@ -848,7 +876,9 @@ async def patch_user( data=update_data, ) - scim_user = await ScimTransformations.transform_litellm_user_to_scim_user(updated_user) + scim_user = await ScimTransformations.transform_litellm_user_to_scim_user( + updated_user + ) return scim_user @@ -947,9 +977,7 @@ async def get_group( """ Get a single group by ID according to SCIM v2 protocol """ - verbose_proxy_logger.debug( - "SCIM GET GROUP request for group_id=%s", group_id - ) + verbose_proxy_logger.debug("SCIM GET GROUP request for group_id=%s", group_id) try: team = await _check_team_exists(group_id) @@ -981,9 +1009,9 @@ async def create_group( ) try: prisma_client = await _get_prisma_client_or_raise_exception() - + # Generate ID if not provided - team_id = group.id or str(uuid.uuid4()) + team_id = group.id or group.externalId or str(uuid.uuid4()) # Check if team already exists existing_team = await prisma_client.db.litellm_teamtable.find_unique( @@ -998,7 +1026,10 @@ async def create_group( # Extract and process group members (creating users that don't exist) member_result = await _extract_group_member_ids(group) - members_with_roles = [Member(user_id=member_id, role="user") for member_id in member_result.all_member_ids] + members_with_roles = [ + Member(user_id=member_id, role="user") + for member_id in member_result.all_member_ids + ] # Create team in database created_team = await new_team( @@ -1043,13 +1074,17 @@ async def update_group( # Extract and process group members (creating users that don't exist) member_result = await _extract_group_member_ids(group) - verbose_proxy_logger.debug(f"SCIM PUT GROUP all_member_ids: {member_result.all_member_ids}") - verbose_proxy_logger.debug(f"SCIM PUT GROUP created_users: {len(member_result.created_users)}") + verbose_proxy_logger.debug( + f"SCIM PUT GROUP all_member_ids: {member_result.all_member_ids}" + ) + verbose_proxy_logger.debug( + f"SCIM PUT GROUP created_users: {len(member_result.created_users)}" + ) # Prepare update data existing_metadata = existing_team.metadata if existing_team.metadata else {} updated_metadata = {**existing_metadata, "scim_data": group.model_dump()} - + update_data = { "team_alias": group.displayName, "metadata": safe_dumps(updated_metadata), @@ -1066,7 +1101,7 @@ async def update_group( verbose_proxy_logger.debug(f"SCIM PUT GROUP current_members: {current_members}") final_members = set(member_result.all_member_ids) verbose_proxy_logger.debug(f"SCIM PUT GROUP final_members: {final_members}") - + await _handle_group_membership_changes( group_id=group_id, current_members=current_members, @@ -1094,9 +1129,7 @@ async def delete_group( """ Delete a group according to SCIM v2 protocol """ - verbose_proxy_logger.debug( - "SCIM DELETE GROUP request for group_id=%s", group_id - ) + verbose_proxy_logger.debug("SCIM DELETE GROUP request for group_id=%s", group_id) try: prisma_client = await _get_prisma_client_or_raise_exception() existing_team = await _check_team_exists(group_id) @@ -1124,21 +1157,19 @@ async def delete_group( async def _process_group_patch_operations( - patch_ops: SCIMPatchOp, - existing_team, - prisma_client + patch_ops: SCIMPatchOp, existing_team, prisma_client ) -> Tuple[Dict[str, Any], Set[str]]: """Process patch operations for a group and return update data and final members.""" update_data: Dict[str, Any] = {} - + # Create a fresh copy of existing metadata to avoid Prisma issues existing_metadata = existing_team.metadata or {} metadata = dict(existing_metadata) if existing_metadata else {} - + # Track member changes current_members = set(existing_team.members or []) final_members = current_members.copy() - + # Process each patch operation for op in patch_ops.Operations: path = (op.path or "").lower() @@ -1169,14 +1200,13 @@ async def _process_group_patch_operations( else: # Create the user if they don't exist using our helper created_user = await _create_user_if_not_exists( - user_id=member_id, - created_via="scim_group_patch" + user_id=member_id, created_via="scim_group_patch" ) - + if created_user: valid_members.append(member_id) # If creation failed, user is skipped (logged in helper) - + if op_type == "replace": final_members = set(valid_members) elif op_type == "add": @@ -1194,21 +1224,18 @@ async def _process_group_patch_operations( # Include metadata in update data if it exists if metadata: update_data["metadata"] = metadata - + return update_data, final_members async def _apply_group_patch_updates( - group_id: str, - update_data: Dict[str, Any], - final_members: Set[str], - prisma_client + group_id: str, update_data: Dict[str, Any], final_members: Set[str], prisma_client ): """Apply patch updates to the group in the database.""" # Serialize metadata if present if "metadata" in update_data and isinstance(update_data["metadata"], dict): update_data["metadata"] = safe_dumps(update_data["metadata"]) - + # Update members list update_data["members"] = list(final_members) @@ -1217,22 +1244,20 @@ async def _apply_group_patch_updates( where={"team_id": group_id}, data=update_data, ) - + return updated_team async def _handle_group_membership_changes( - group_id: str, - current_members: Set[str], - final_members: Set[str] + group_id: str, current_members: Set[str], final_members: Set[str] ): """Handle adding/removing members from the group.""" members_to_add = final_members - current_members members_to_remove = current_members - final_members - + verbose_proxy_logger.debug(f"members_to_add: {members_to_add}") verbose_proxy_logger.debug(f"members_to_remove: {members_to_remove}") - + # Use existing helper functions for team membership changes for member_id in members_to_add: await patch_team_membership( @@ -1276,7 +1301,7 @@ async def patch_group( update_data, final_members = await _process_group_patch_operations( patch_ops, existing_team, prisma_client ) - + # Track current members for comparison current_members = set(await _get_team_member_user_ids_from_team(existing_team)) @@ -1286,9 +1311,7 @@ async def patch_group( ) # Handle user-team relationship changes - await _handle_group_membership_changes( - group_id, current_members, final_members - ) + await _handle_group_membership_changes(group_id, current_members, final_members) # Convert to SCIM format and return scim_group = await ScimTransformations.transform_litellm_team_to_scim_group( diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 227ea4db9ff4..3bf5f0e5e447 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -5,7 +5,7 @@ import traceback from base64 import b64encode from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast from urllib.parse import urlencode, urlparse import httpx @@ -66,7 +66,7 @@ pass_through_endpoint_logging = PassThroughEndpointLogging() # Global registry to track registered pass-through routes and prevent memory leaks -_registered_pass_through_routes: Dict[str, Dict[str, str]] = {} +_registered_pass_through_routes: Dict[str, Dict[str, Union[str, Dict[str, Any]]]] = {} def get_response_body(response: httpx.Response) -> Optional[dict]: @@ -974,24 +974,73 @@ async def endpoint_func( # type: ignore ] = None, # if pass-through endpoint is a streaming request subpath: str = "", # captures sub-paths when include_subpath=True ): + + from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + InitPassThroughEndpointHelpers, + ) + + if not InitPassThroughEndpointHelpers.is_registered_pass_through_route( + route=endpoint + ): + raise HTTPException( + status_code=404, + detail=f"Pass-through endpoint {endpoint} not found. This could have been deleted or not yet added to the proxy.", + ) + + passthrough_params = ( + InitPassThroughEndpointHelpers.get_registered_pass_through_route( + route=endpoint + ) + ) + target_params = { + "target": target, + "custom_headers": custom_headers, + "forward_headers": _forward_headers, + "merge_query_params": _merge_query_params, + "cost_per_request": cost_per_request, + } + + if passthrough_params is not None: + target_params.update(passthrough_params.get("passthrough_params", {})) + + # Extract and cast parameters with proper types + param_target = target_params.get("target") or target + param_custom_headers = target_params.get("custom_headers", custom_headers) + param_forward_headers = target_params.get( + "forward_headers", _forward_headers + ) + param_merge_query_params = target_params.get( + "merge_query_params", _merge_query_params + ) + param_cost_per_request = target_params.get( + "cost_per_request", cost_per_request + ) + # Construct the full target URL with subpath if needed full_target = ( HttpPassThroughEndpointHelpers.construct_target_url_with_subpath( - base_target=target, subpath=subpath, include_subpath=include_subpath + base_target=cast(str, param_target), + subpath=subpath, + include_subpath=include_subpath, ) ) + # Ensure custom_headers is a dict + headers_dict = ( + param_custom_headers if isinstance(param_custom_headers, dict) else {} + ) + return await pass_through_request( # type: ignore request=request, target=full_target, - custom_headers=custom_headers or {}, + custom_headers=headers_dict, user_api_key_dict=user_api_key_dict, - forward_headers=_forward_headers, - merge_query_params=_merge_query_params, + forward_headers=cast(Optional[bool], param_forward_headers), + merge_query_params=cast(Optional[bool], param_merge_query_params), query_params=query_params, stream=stream, custom_body=custom_body, - cost_per_request=cost_per_request, + cost_per_request=cast(Optional[float], param_cost_per_request), custom_llm_provider=custom_llm_provider, ) @@ -1592,6 +1641,14 @@ def add_exact_path_route( "endpoint_id": endpoint_id, "path": path, "type": "exact", + "passthrough_params": { + "target": target, + "custom_headers": custom_headers, + "forward_headers": forward_headers, + "merge_query_params": merge_query_params, + "dependencies": dependencies, + "cost_per_request": cost_per_request, + }, } @staticmethod @@ -1645,6 +1702,14 @@ def add_subpath_route( "endpoint_id": endpoint_id, "path": path, "type": "subpath", + "passthrough_params": { + "target": target, + "custom_headers": custom_headers, + "forward_headers": forward_headers, + "merge_query_params": merge_query_params, + "dependencies": dependencies, + "cost_per_request": cost_per_request, + }, } @staticmethod @@ -1661,6 +1726,11 @@ def remove_endpoint_routes(endpoint_id: str): "Removed pass-through route from registry: %s", key ) + @staticmethod + def clear_all_pass_through_routes(): + """Clear all pass-through routes from the registry""" + _registered_pass_through_routes.clear() + @staticmethod def is_registered_pass_through_route(route: str) -> bool: """ @@ -1694,11 +1764,42 @@ def is_registered_pass_through_route(route: str) -> bool: return False + @staticmethod + def get_registered_pass_through_route(route: str) -> Optional[Dict[str, Any]]: + """Get passthrough params for a given route""" + for key in _registered_pass_through_routes.keys(): + parts = key.split(":", 2) # Split into [endpoint_id, type, path] + if len(parts) == 3: + route_type = parts[1] + registered_path = parts[2] + + if route_type == "exact" and route == registered_path: + return _registered_pass_through_routes[key] + elif route_type == "subpath": + if route == registered_path or route.startswith( + registered_path + "/" + ): + return _registered_pass_through_routes[key] + + return None + + +def _get_combined_pass_through_endpoints( + pass_through_endpoints: Union[List[Dict], List[PassThroughGenericEndpoint]], + config_pass_through_endpoints: List[Dict], +): + """Get combined pass-through endpoints from db + config""" + return pass_through_endpoints + config_pass_through_endpoints + async def initialize_pass_through_endpoints( pass_through_endpoints: Union[List[Dict], List[PassThroughGenericEndpoint]], ): """ + 1. Create a global list of pass-through endpoints (db + config) + 2. Clear all existing pass-through endpoints from the FastAPI app routes + 3. Add new endpoints to the in-memory registry + Initialize a list of pass-through endpoints by adding them to the FastAPI app routes Args: @@ -1711,9 +1812,22 @@ async def initialize_pass_through_endpoints( verbose_proxy_logger.debug("initializing pass through endpoints") from litellm.proxy._types import CommonProxyErrors, LiteLLMRoutes - from litellm.proxy.proxy_server import app, premium_user + from litellm.proxy.proxy_server import app, general_settings, premium_user + + ## get combined pass-through endpoints from db + config + config_pass_through_endpoints = general_settings.get("pass_through_endpoints") + combined_pass_through_endpoints: List[Union[Dict, PassThroughGenericEndpoint]] + if config_pass_through_endpoints is not None: + combined_pass_through_endpoints = _get_combined_pass_through_endpoints( # type: ignore + pass_through_endpoints, config_pass_through_endpoints + ) + else: + combined_pass_through_endpoints = pass_through_endpoints # type: ignore - for endpoint in pass_through_endpoints: + ## clear all existing pass-through endpoints from the FastAPI app routes + InitPassThroughEndpointHelpers.clear_all_pass_through_routes() + + for endpoint in combined_pass_through_endpoints: if isinstance(endpoint, PassThroughGenericEndpoint): endpoint = endpoint.model_dump() @@ -1930,6 +2044,7 @@ async def update_pass_through_endpoints( field_value=pass_through_endpoint_data, config_type="general_settings", ) + await update_config_general_settings( data=updated_data, user_api_key_dict=user_api_key_dict )