From c999b4efe12cc4513ac726cadf5bac870b40498e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 10 Jan 2025 14:16:30 -0800 Subject: [PATCH] (litellm sdk - perf improvement) - use O(1) set lookups for checking llm providers / models (#7672) * fix get model info logic to use O(1) lookups * perf - use O(1) lookup for get llm provider --- litellm/__init__.py | 1 + litellm/litellm_core_utils/get_llm_provider_logic.py | 6 ++++-- litellm/types/utils.py | 4 ++++ litellm/utils.py | 5 ++--- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index fc4121adf6f3..d6ffbb73874f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -861,6 +861,7 @@ def add_known_models(): + azure_text_models ) +model_list_set = set(model_list) provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 74e0ddc36230..834e35c733fa 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -141,7 +141,7 @@ def get_llm_provider( # noqa: PLR0915 # check if llm provider part of model name if ( model.split("/", 1)[0] in litellm.provider_list - and model.split("/", 1)[0] not in litellm.model_list + and model.split("/", 1)[0] not in litellm.model_list_set and len(model.split("/")) > 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351 ): @@ -210,7 +210,9 @@ def get_llm_provider( # noqa: PLR0915 dynamic_api_key = get_secret_str("DEEPSEEK_API_KEY") elif endpoint == "https://api.friendli.ai/serverless/v1": custom_llm_provider = "friendliai" - dynamic_api_key = get_secret_str("FRIENDLIAI_API_KEY") or get_secret("FRIENDLI_TOKEN") + dynamic_api_key = get_secret_str( + "FRIENDLIAI_API_KEY" + ) or get_secret("FRIENDLI_TOKEN") elif endpoint == "api.galadriel.com/v1": custom_llm_provider = "galadriel" dynamic_api_key = get_secret_str("GALADRIEL_API_KEY") diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 39a329f8893a..09567b88b430 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1817,6 +1817,10 @@ class LlmProviders(str, Enum): HUMANLOOP = "humanloop" +# Create a set of all provider values for quick lookup +LlmProvidersSet = {provider.value for provider in LlmProviders} + + class LiteLLMLoggingBaseClass: """ Base class for logging pre and post call diff --git a/litellm/utils.py b/litellm/utils.py index 5aae94e4f031..b46707b48f60 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -133,6 +133,7 @@ Function, ImageResponse, LlmProviders, + LlmProvidersSet, Message, ModelInfo, ModelInfoBase, @@ -4108,9 +4109,7 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None - if custom_llm_provider and custom_llm_provider in [ - provider.value for provider in LlmProviders - ]: + if custom_llm_provider and custom_llm_provider in LlmProvidersSet: # Check if the provider string exists in LlmProviders enum provider_config = ProviderConfigManager.get_provider_model_info( model=model, provider=LlmProviders(custom_llm_provider)