diff --git a/litellm/__init__.py b/litellm/__init__.py index fc4121adf6f3..d6ffbb73874f 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -861,6 +861,7 @@ def add_known_models(): + azure_text_models ) +model_list_set = set(model_list) provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 74e0ddc36230..834e35c733fa 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -141,7 +141,7 @@ def get_llm_provider( # noqa: PLR0915 # check if llm provider part of model name if ( model.split("/", 1)[0] in litellm.provider_list - and model.split("/", 1)[0] not in litellm.model_list + and model.split("/", 1)[0] not in litellm.model_list_set and len(model.split("/")) > 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351 ): @@ -210,7 +210,9 @@ def get_llm_provider( # noqa: PLR0915 dynamic_api_key = get_secret_str("DEEPSEEK_API_KEY") elif endpoint == "https://api.friendli.ai/serverless/v1": custom_llm_provider = "friendliai" - dynamic_api_key = get_secret_str("FRIENDLIAI_API_KEY") or get_secret("FRIENDLI_TOKEN") + dynamic_api_key = get_secret_str( + "FRIENDLIAI_API_KEY" + ) or get_secret("FRIENDLI_TOKEN") elif endpoint == "api.galadriel.com/v1": custom_llm_provider = "galadriel" dynamic_api_key = get_secret_str("GALADRIEL_API_KEY") diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 39a329f8893a..09567b88b430 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1817,6 +1817,10 @@ class LlmProviders(str, Enum): HUMANLOOP = "humanloop" +# Create a set of all provider values for quick lookup +LlmProvidersSet = {provider.value for provider in LlmProviders} + + class LiteLLMLoggingBaseClass: """ Base class for logging pre and post call diff --git a/litellm/utils.py b/litellm/utils.py index 5aae94e4f031..b46707b48f60 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -133,6 +133,7 @@ Function, ImageResponse, LlmProviders, + LlmProvidersSet, Message, ModelInfo, ModelInfoBase, @@ -4108,9 +4109,7 @@ def _get_model_info_helper( # noqa: PLR0915 ): _model_info = None - if custom_llm_provider and custom_llm_provider in [ - provider.value for provider in LlmProviders - ]: + if custom_llm_provider and custom_llm_provider in LlmProvidersSet: # Check if the provider string exists in LlmProviders enum provider_config = ProviderConfigManager.get_provider_model_info( model=model, provider=LlmProviders(custom_llm_provider)