From 58281f430ab9b56c0e706f93ed3344a17542fdcb Mon Sep 17 00:00:00 2001
From: mxrcooo <91695404+mxrcooo@users.noreply.github.com>
Date: Mon, 6 Oct 2025 00:58:31 +0200
Subject: [PATCH] fix(proxy): enforce virtual-key model budgets in routing
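
`async_filter_deployments` previously returned every healthy deployment
unchanged, so a virtual key could keep routing to models whose budget was
already exhausted. `async_log_success_event` also tracked spend against a
hard-coded `BudgetConfig(time_period="1d", budget_limit=0.1)` instead of the
budgets configured on the key.

This patch parses the per-model budgets attached to the request metadata
(`user_api_key_model_max_budget`), drops deployments whose tracked spend
already meets their budget, and raises `litellm.BudgetExceededError` when
every deployment is over budget. Spend tracking now resolves the budget
config per model and scopes the cache keys per key, model, and budget
window. This applies to virtual keys with per-model budgets, e.g. a key
whose `model_max_budget` is:

    {"gpt-5": {"budget_limit": 10, "time_period": "30d"}}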
---
 .../proxy/hooks/model_max_budget_limiter.py   | 273 +++++++++++++++++-
 ...test_unit_test_max_model_budget_limiter.py | 134 +++++++++
 2 files changed, 394 insertions(+), 13 deletions(-)

diff --git a/litellm/proxy/hooks/model_max_budget_limiter.py b/litellm/proxy/hooks/model_max_budget_limiter.py
index ac02c915366a..f2ffef68256e 100644
--- a/litellm/proxy/hooks/model_max_budget_limiter.py
+++ b/litellm/proxy/hooks/model_max_budget_limiter.py
@@ -1,5 +1,5 @@
 import json
-from typing import List, Optional
+from typing import Dict, List, Optional, Set
 
 import litellm
 from litellm._logging import verbose_proxy_logger
@@ -140,7 +140,166 @@ async def async_filter_deployments(
         request_kwargs: Optional[dict] = None,
         parent_otel_span: Optional[Span] = None,  # type: ignore
     ) -> List[dict]:
-        return healthy_deployments
+        if len(healthy_deployments) == 0:
+            return healthy_deployments
+
+        request_metadata: Dict = {}
+        if isinstance(request_kwargs, dict):
+            if isinstance(request_kwargs.get("metadata"), dict):
+                request_metadata = request_kwargs.get("metadata", {})
+            elif isinstance(request_kwargs.get("litellm_metadata"), dict):
+                request_metadata = request_kwargs.get("litellm_metadata", {})
+
+        user_api_key_model_max_budget = request_metadata.get(
+            "user_api_key_model_max_budget"
+        )
+        virtual_key_hash: Optional[str] = request_metadata.get("user_api_key_hash")
+
+        if (
+            not user_api_key_model_max_budget
+            or not isinstance(user_api_key_model_max_budget, dict)
+            or virtual_key_hash is None
+        ):
+            return healthy_deployments
+
+        internal_model_max_budget: Dict[str, BudgetConfig] = {}
+        for budget_model, budget_info in user_api_key_model_max_budget.items():
+            try:
+                if isinstance(budget_info, BudgetConfig):
+                    internal_model_max_budget[budget_model] = budget_info
+                elif isinstance(budget_info, dict):
+                    internal_model_max_budget[budget_model] = BudgetConfig(
+                        **budget_info
+                    )
+                else:
+                    verbose_proxy_logger.debug(
+                        "Unsupported budget info type for model %s: %s",
+                        budget_model,
+                        type(budget_info),
+                    )
+            except Exception as e:
+                verbose_proxy_logger.debug(
+                    "Failed to parse budget config for model %s - %s",
+                    budget_model,
+                    str(e),
+                )
+
+        if len(internal_model_max_budget) == 0:
+            return healthy_deployments
+
+        filtered_deployments: List[dict] = []
+        first_violation: Optional[Dict] = None
+        request_model_candidates: List[str] = [model]
+        if isinstance(request_kwargs, dict):
+            request_model = request_kwargs.get("model")
+            if isinstance(request_model, str) and request_model not in request_model_candidates:
+                request_model_candidates.append(request_model)
+
+        for deployment in healthy_deployments:
+            deployment_params: Dict = deployment.get("litellm_params", {})
+            deployment_model_name: Optional[str] = deployment.get("model_name")
+            deployment_model: Optional[str] = deployment_params.get("model")
+
+            candidate_models: List[str] = []
+            for candidate in request_model_candidates:
+                if isinstance(candidate, str):
+                    candidate_models.append(candidate)
+            if isinstance(deployment_model_name, str):
+                candidate_models.append(deployment_model_name)
+            if isinstance(deployment_model, str):
+                candidate_models.append(deployment_model)
+
+            matched_budget: Optional[BudgetConfig] = None
+            resolved_model_name: Optional[str] = None
+
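+            # Try candidates in priority order: the model group the router was
+            # asked for, the raw request model, then the deployment's public
+            # `model_name` and its provider-qualified `litellm_params.model`.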
+            for candidate in candidate_models:
+                if not isinstance(candidate, str):
+                    continue
+                budget_config = self._get_request_model_budget_config(
+                    model=candidate,
+                    internal_model_max_budget=internal_model_max_budget,
+                )
+                if budget_config is None and candidate != "":
+                    sanitized_candidate = self._get_model_without_custom_llm_provider(
+                        candidate
+                    )
+                    budget_config = self._get_request_model_budget_config(
+                        model=sanitized_candidate,
+                        internal_model_max_budget=internal_model_max_budget,
+                    )
+                    if budget_config is not None:
+                        candidate = sanitized_candidate
+
+                if budget_config is not None:
+                    matched_budget = budget_config
+                    resolved_model_name = candidate
+                    break
+
+            if matched_budget is None or resolved_model_name is None:
+                filtered_deployments.append(deployment)
+                continue
+
+            if not matched_budget.max_budget or matched_budget.max_budget <= 0:
+                filtered_deployments.append(deployment)
+                continue
+
+            current_spend = await self._get_virtual_key_spend_for_model(
+                user_api_key_hash=virtual_key_hash,
+                model=resolved_model_name,
+                key_budget_config=matched_budget,
+            )
+
+            if (
+                current_spend is not None
+                and matched_budget.max_budget is not None
+                and current_spend >= matched_budget.max_budget
+            ):
+                verbose_proxy_logger.debug(
+                    "Filtered deployment %s for virtual key %s due to exceeded budget. Spend=%s, Max Budget=%s",
+                    deployment_model,
+                    virtual_key_hash,
+                    current_spend,
+                    matched_budget.max_budget,
+                )
+                if first_violation is None:
+                    first_violation = {
+                        "model": resolved_model_name,
+                        "current_spend": current_spend,
+                        "max_budget": matched_budget.max_budget,
+                        "key_alias": request_metadata.get("user_api_key_alias"),
+                    }
+                continue
+
+            filtered_deployments.append(deployment)
+
+        if len(filtered_deployments) > 0:
+            return filtered_deployments
+
+        if first_violation is not None:
+            key_alias = first_violation.get("key_alias")
+            if key_alias is not None:
+                message = (
+                    "LiteLLM Virtual Key: {}, key_alias: {} exceeded budget for model={}".format(
+                        virtual_key_hash,
+                        key_alias,
+                        first_violation["model"],
+                    )
+                )
+            else:
+                message = "LiteLLM Virtual Key: {} exceeded budget for model={}".format(
+                    virtual_key_hash,
+                    first_violation["model"],
+                )
+            raise litellm.BudgetExceededError(
+                message=message,
+                current_cost=first_violation["current_spend"],
+                max_budget=first_violation["max_budget"],
+            )
+
+        return filtered_deployments
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         """
@@ -169,21 +328,109 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
                 user_api_key_model_max_budget,
             )
             return
-        response_cost: float = standard_logging_payload.get("response_cost", 0)
-        model = standard_logging_payload.get("model")
-
-        virtual_key = standard_logging_payload.get("metadata").get("user_api_key_hash")
-        model = standard_logging_payload.get("model")
-        if virtual_key is not None:
-            budget_config = BudgetConfig(time_period="1d", budget_limit=0.1)
-            virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{budget_config.budget_duration}"
-            virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
+        response_cost: float = float(standard_logging_payload.get("response_cost", 0) or 0)
+        standard_logging_metadata = standard_logging_payload.get("metadata") or {}
+        virtual_key = standard_logging_metadata.get("user_api_key_hash")
+        if virtual_key is None:
+            verbose_proxy_logger.debug(
+                "Skipping virtual key budget tracking - `user_api_key_hash` missing in standard logging payload metadata."
+            )
+            return
+
+        proxy_request_model: Optional[str] = None
+        proxy_request = kwargs.get("proxy_server_request") or {}
+        if isinstance(proxy_request, dict):
+            request_body = proxy_request.get("body")
+            if isinstance(request_body, dict):
+                proxy_request_model = request_body.get("model")
+
+        internal_model_max_budget: Dict[str, BudgetConfig] = {}
+        for budget_model, budget_info in user_api_key_model_max_budget.items():
+            try:
+                if isinstance(budget_info, BudgetConfig):
+                    internal_model_max_budget[budget_model] = budget_info
+                elif isinstance(budget_info, dict):
+                    internal_model_max_budget[budget_model] = BudgetConfig(**budget_info)
+                else:
+                    verbose_proxy_logger.debug(
+                        "Unsupported budget info type for model %s: %s",
+                        budget_model,
+                        type(budget_info),
+                    )
+            except Exception as e:
+                verbose_proxy_logger.debug(
+                    "Failed to parse budget config for model %s - %s",
+                    budget_model,
+                    str(e),
+                )
+
+        if len(internal_model_max_budget) == 0:
+            verbose_proxy_logger.debug(
+                "Not tracking virtual key spend - no parsable budget configs."
+            )
+            return
+
+        model_candidates: List[str] = []
+        if proxy_request_model:
+            model_candidates.append(proxy_request_model)
+        response_model = standard_logging_payload.get("model")
+        if response_model:
+            model_candidates.append(response_model)
+
+        seen_models: Set[str] = set()
+        spend_tracked = False
+
+        for candidate in model_candidates:
+            if not candidate or candidate in seen_models:
+                continue
+            seen_models.add(candidate)
+
+            resolved_model_name = candidate
+            budget_config = internal_model_max_budget.get(candidate)
+            if budget_config is None:
+                sanitized_model = self._get_model_without_custom_llm_provider(candidate)
+                if sanitized_model != candidate:
+                    budget_config = internal_model_max_budget.get(sanitized_model)
+                    if budget_config is not None:
+                        resolved_model_name = sanitized_model
+
+            if budget_config is None:
+                verbose_proxy_logger.debug(
+                    "No budget config matched for candidate model %s",
+                    candidate,
+                )
+                continue
+
+            if budget_config.budget_duration is None:
+                verbose_proxy_logger.debug(
+                    "Budget config for model %s missing `budget_duration`, skipping spend tracking.",
+                    resolved_model_name,
+                )
+                continue
+
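+            # Spend is keyed per virtual key, model, and budget duration; the
+            # budget window start time is keyed per virtual key and model.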
+            spend_key = (
+                f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{resolved_model_name}:{budget_config.budget_duration}"
+            )
+            start_time_key = (
+                f"virtual_key_budget_start_time:{virtual_key}:{resolved_model_name}"
+            )
+
             await self._increment_spend_for_key(
                 budget_config=budget_config,
-                spend_key=virtual_spend_key,
-                start_time_key=virtual_start_time_key,
+                spend_key=spend_key,
+                start_time_key=start_time_key,
                 response_cost=response_cost,
             )
+            spend_tracked = True
+
+        if not spend_tracked:
+            verbose_proxy_logger.debug(
+                "Virtual key spend not tracked - no candidate models matched the configured budgets."
+            )
+            return
+
         verbose_proxy_logger.debug(
             "current state of in memory cache %s",
             json.dumps(
diff --git a/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py b/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py
index fc8373a17468..a61d385b82a7 100644
--- a/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py
+++ b/tests/proxy_unit_tests/test_unit_test_max_model_budget_limiter.py
@@ -123,3 +123,137 @@ async def test_get_virtual_key_spend_for_model(budget_limiter):
         key_budget_config=budget_config,
     )
     assert spend == 50.0
+
+
+@pytest.mark.asyncio
+async def test_virtual_key_budget_tracking_respects_duration(budget_limiter):
+    user_api_key = UserAPIKeyAuth(
+        token="virtual-key",
+        key_alias="vk-alias",
+        model_max_budget={
+            "gpt-5": {"budget_limit": 1e-9, "time_period": "30d"}
+        },
+    )
+
+    logging_kwargs = {
+        "standard_logging_object": {
+            "response_cost": 6e-10,
+            "model": "gpt-5",
+            "metadata": {"user_api_key_hash": user_api_key.token},
+        },
+        "litellm_params": {
+            "metadata": {
+                "user_api_key_model_max_budget": user_api_key.model_max_budget
+            }
+        },
+        "proxy_server_request": {"body": {"model": "gpt-5"}},
+    }
+
+    await budget_limiter.async_log_success_event(
+        logging_kwargs, None, datetime.now(), datetime.now()
+    )
+
+    assert (
+        await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-5")
+        is True
+    )
+
+    await budget_limiter.async_log_success_event(
+        logging_kwargs, None, datetime.now(), datetime.now()
+    )
+
+    with pytest.raises(litellm.BudgetExceededError):
+        await budget_limiter.is_key_within_model_budget(user_api_key, "gpt-5")
+
+
+@pytest.mark.asyncio
+async def test_async_filter_deployments_filters_over_budget(budget_limiter):
+    virtual_key_hash = "vk-over-budget"
+    user_budget = {
+        "gpt-5": {"budget_limit": 1e-9, "time_period": "30d"},
+        "gpt-5-mini": {"budget_limit": 1e-9, "time_period": "30d"},
+    }
+
+    await budget_limiter.dual_cache.async_set_cache(
+        key=f"virtual_key_spend:{virtual_key_hash}:gpt-5:30d",
+        value=2e-9,
+    )
+
+    healthy_deployments = [
+        {
+            "model_name": "feedback",
+            "litellm_params": {"model": "openai/gpt-5"},
+        },
+        {
+            "model_name": "feedback-fallback",
+            "litellm_params": {"model": "openai/gpt-5-mini"},
+        },
+    ]
+
+    request_kwargs = {
+        "model": "feedback",
+        "metadata": {
+            "user_api_key_model_max_budget": user_budget,
+            "user_api_key_hash": virtual_key_hash,
+            "user_api_key_alias": "vk-alias",
+        },
+    }
+
+    filtered = await budget_limiter.async_filter_deployments(
+        model="feedback",
+        healthy_deployments=healthy_deployments,
+        messages=None,
+        request_kwargs=request_kwargs,
+    )
+
+    assert len(filtered) == 1
+    assert filtered[0]["model_name"] == "feedback-fallback"
+
+
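+# If every deployment serving the requested model group is over budget, the
+# filter should raise litellm.BudgetExceededError rather than return an empty
+# deployment list.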
"user_api_key_model_max_budget": user_budget, + "user_api_key_hash": virtual_key_hash, + "user_api_key_alias": "vk-alias", + }, + } + + with pytest.raises(litellm.BudgetExceededError): + await budget_limiter.async_filter_deployments( + model="feedback", + healthy_deployments=healthy_deployments, + messages=None, + request_kwargs=request_kwargs, + )