(litellm sdk - perf improvement) - use O(1) set lookups for checking llm providers / models (#7672)

* fix get model info logic to use O(1) lookups

* perf - use O(1) lookup for get llm provider
ishaan-jaff authored Jan 10, 2025
1 parent b3bd15e commit c999b4e
Showing 4 changed files with 11 additions and 5 deletions.
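
The change in a nutshell, as a minimal sketch (the names below are illustrative, not the actual litellm symbols): membership tests against a Python list scan every element, so `x in some_list` is O(n), while `x in some_set` hashes the key and is O(1) on average. The commit precomputes the model and provider collections as sets once, so the hot paths below no longer rebuild or scan lists on every call.

# Minimal sketch of the pattern this commit applies (illustrative names only).
known_models = ["gpt-4o", "claude-3-5-sonnet", "mistral-tiny"]  # O(n) to search
known_models_set = set(known_models)  # built once, up front

def is_known_model(name: str) -> bool:
    # Average O(1) hash lookup instead of scanning the whole list.
    return name in known_models_set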
1 change: 1 addition & 0 deletions litellm/__init__.py
@@ -861,6 +861,7 @@ def add_known_models():
      + azure_text_models
  )

+ model_list_set = set(model_list)

  provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)

6 changes: 4 additions & 2 deletions litellm/litellm_core_utils/get_llm_provider_logic.py
@@ -141,7 +141,7 @@ def get_llm_provider( # noqa: PLR0915
  # check if llm provider part of model name
  if (
      model.split("/", 1)[0] in litellm.provider_list
-     and model.split("/", 1)[0] not in litellm.model_list
+     and model.split("/", 1)[0] not in litellm.model_list_set
      and len(model.split("/"))
      > 1  # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
  ):
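
A hedged illustration of what this branch decides (the inputs and the helper below are made up for the example, not litellm code): the text before the first "/" is treated as a provider prefix only when it names a known provider, is not itself a model name, and a model name actually follows it.

# Illustrative only -- simplified stand-ins for litellm.provider_list / litellm.model_list_set.
provider_list = {"openai", "mistral", "anthropic"}
model_list_set = {"gpt-4o", "mistral-tiny"}

def split_provider(model: str):
    prefix = model.split("/", 1)[0]
    if prefix in provider_list and prefix not in model_list_set and len(model.split("/")) > 1:
        # e.g. "mistral/mistral-tiny" -> ("mistral", "mistral-tiny")
        return prefix, model.split("/", 1)[1]
    # e.g. plain "mistral" (as in `litellm --model mistral`) is left untouched
    return None, model

Switching the second check from litellm.model_list (a list) to litellm.model_list_set only changes the cost of the lookup, not the outcome.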
@@ -210,7 +210,9 @@ def get_llm_provider( # noqa: PLR0915
      dynamic_api_key = get_secret_str("DEEPSEEK_API_KEY")
  elif endpoint == "https://api.friendli.ai/serverless/v1":
      custom_llm_provider = "friendliai"
-     dynamic_api_key = get_secret_str("FRIENDLIAI_API_KEY") or get_secret("FRIENDLI_TOKEN")
+     dynamic_api_key = get_secret_str(
+         "FRIENDLIAI_API_KEY"
+     ) or get_secret("FRIENDLI_TOKEN")
  elif endpoint == "api.galadriel.com/v1":
      custom_llm_provider = "galadriel"
      dynamic_api_key = get_secret_str("GALADRIEL_API_KEY")
4 changes: 4 additions & 0 deletions litellm/types/utils.py
@@ -1817,6 +1817,10 @@ class LlmProviders(str, Enum):
      HUMANLOOP = "humanloop"


+ # Create a set of all provider values for quick lookup
+ LlmProvidersSet = {provider.value for provider in LlmProviders}
+
+
  class LiteLLMLoggingBaseClass:
      """
      Base class for logging pre and post call
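
For context, a minimal sketch of the pattern this file now uses, with a toy enum standing in for the real LlmProviders: the value set is built once at import time, so later provider checks are a single hash lookup instead of a fresh list comprehension per call.

from enum import Enum

class Providers(str, Enum):  # toy stand-in for LlmProviders
    OPENAI = "openai"
    ANTHROPIC = "anthropic"

# Built once when the module is imported, analogous to LlmProvidersSet.
ProvidersSet = {provider.value for provider in Providers}

assert "openai" in ProvidersSet  # average O(1) membership check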
5 changes: 2 additions & 3 deletions litellm/utils.py
@@ -133,6 +133,7 @@
  Function,
  ImageResponse,
  LlmProviders,
+ LlmProvidersSet,
  Message,
  ModelInfo,
  ModelInfoBase,
@@ -4108,9 +4109,7 @@ def _get_model_info_helper( # noqa: PLR0915
  ):
      _model_info = None

-     if custom_llm_provider and custom_llm_provider in [
-         provider.value for provider in LlmProviders
-     ]:
+     if custom_llm_provider and custom_llm_provider in LlmProvidersSet:
          # Check if the provider string exists in LlmProviders enum
          provider_config = ProviderConfigManager.get_provider_model_info(
              model=model, provider=LlmProviders(custom_llm_provider)
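
The removed code rebuilt `[provider.value for provider in LlmProviders]` and scanned it on every `_get_model_info_helper` call; the replacement reuses the set built once in litellm/types/utils.py. A rough way to see the difference, as a hedged micro-benchmark sketch (toy enum, numbers will vary by machine and enum size):

import timeit
from enum import Enum

# Toy enum with 100 members, a stand-in for LlmProviders.
Providers = Enum("Providers", {f"P{i}": f"provider_{i}" for i in range(100)}, type=str)
providers_set = {p.value for p in Providers}  # precomputed once

per_call_list = timeit.timeit(
    '"provider_99" in [p.value for p in Providers]', globals=globals(), number=10_000
)
precomputed_set = timeit.timeit(
    '"provider_99" in providers_set', globals=globals(), number=10_000
)
print(per_call_list, precomputed_set)  # the precomputed set is typically far faster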
