Skip to content

Commit

Permalink
[OPIK-838] Add proper support for litellm completion kwargs (#1114)
Browse files Browse the repository at this point in the history
* Rework the params filtering logic in LiteLLMChatModel. Nothing except response_format can be discarded if the model does not support it.

* Update model_name docstring

* Update the docstring

* Update LLM model -> LLM
  • Loading branch information
alexkuzmik authored Jan 23, 2025
1 parent 890b0d9 commit 73e3f50
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
4 changes: 2 additions & 2 deletions sdks/python/src/opik/evaluation/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def generate_provider_response(self, **kwargs: Any) -> Any:
kwargs: arguments required by the provider to generate a response.
Returns:
Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
Any: The response from the model provider, which can be of any type depending on the use case and LLM.
"""
pass

Expand All @@ -72,6 +72,6 @@ async def agenerate_provider_response(self, **kwargs: Any) -> Any:
kwargs: arguments required by the provider to generate a response.
Returns:
Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
Any: The response from the model provider, which can be of any type depending on the use case and LLM.
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,14 @@ def __init__(
You can find all possible completion_kwargs parameters here: https://docs.litellm.ai/docs/completion/input.
Args:
model_name: The name of the LLM model to be used.
must_support_arguments: A list of arguments that the provider must support.
model_name: The name of the LLM to be used.
This parameter will be passed to `litellm.completion(model=model_name)` so you don't need to pass
the `model` argument separately inside **completion_kwargs.
must_support_arguments: A list of openai-like arguments that the given model + provider pair must support.
A `litellm.get_supported_openai_params(model_name)` call is used to get
the supported arguments. If any of them is missing, a ValueError is raised.
You can pass the arguments from the table: https://docs.litellm.ai/docs/completion/input#translated-openai-params
**completion_kwargs: key-value arguments to always pass additionally into `litellm.completion` function.
"""

Expand All @@ -49,8 +53,8 @@ def __init__(
self._check_model_name()
self._check_must_support_arguments(must_support_arguments)

self._completion_kwargs: Dict[str, Any] = self._filter_supported_params(
completion_kwargs
self._completion_kwargs: Dict[str, Any] = (
self._remove_unnecessary_not_supported_params(completion_kwargs)
)

self._engine = litellm
Expand All @@ -62,10 +66,6 @@ def supported_params(self) -> Set[str]:
)
self._ensure_supported_params(supported_params)

# Add metadata and success_callback as a parameter that is always supported
supported_params.add("metadata")
supported_params.add("success_callback")

return supported_params

def _ensure_supported_params(self, params: Set[str]) -> None:
Expand Down Expand Up @@ -101,18 +101,21 @@ def _check_must_support_arguments(self, args: Optional[List[str]]) -> None:
if key not in self.supported_params:
raise ValueError(f"Unsupported parameter: '{key}'!")

def _filter_supported_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
valid_params = {}

for key, value in params.items():
if key not in self.supported_params:
LOGGER.debug(
f"This model does not support the '{key}' parameter and it has been ignored."
)
else:
valid_params[key] = value
def _remove_unnecessary_not_supported_params(
self, params: Dict[str, Any]
) -> Dict[str, Any]:
filtered_params = {**params}

if (
"response_format" in params
and "response_format" not in self.supported_params
):
filtered_params.pop("response_format")
LOGGER.debug(
"This model does not support the response_format parameter and it will be ignored."
)

return valid_params
return filtered_params

def generate_string(self, input: str, **kwargs: Any) -> str:
"""
Expand All @@ -127,7 +130,7 @@ def generate_string(self, input: str, **kwargs: Any) -> str:
str: The generated string output.
"""

valid_litellm_params = self._filter_supported_params(kwargs)
valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)

request = [
{
Expand All @@ -154,13 +157,13 @@ def generate_provider_response(
kwargs: arguments required by the provider to generate a response.
Returns:
Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
Any: The response from the model provider, which can be of any type depending on the use case and LLM.
"""

# we need to pop messages first, and afterwards we will validate the remaining params
messages = kwargs.pop("messages")

valid_litellm_params = self._filter_supported_params(kwargs)
valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)
all_kwargs = {**self._completion_kwargs, **valid_litellm_params}

if opik_monitor.enabled_in_config():
Expand All @@ -185,7 +188,7 @@ async def agenerate_string(self, input: str, **kwargs: Any) -> str:
str: The generated string output.
"""

valid_litellm_params = self._filter_supported_params(kwargs)
valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)

request = [
{
Expand All @@ -209,13 +212,13 @@ async def agenerate_provider_response(self, **kwargs: Any) -> ModelResponse:
kwargs: arguments required by the provider to generate a response.
Returns:
Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
Any: The response from the model provider, which can be of any type depending on the use case and LLM.
"""

# we need to pop messages first, and afterwards we will validate the remaining params
messages = kwargs.pop("messages")

valid_litellm_params = self._filter_supported_params(kwargs)
valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)
all_kwargs = {**self._completion_kwargs, **valid_litellm_params}

if opik_monitor.enabled_in_config():
Expand Down

0 comments on commit 73e3f50

Please sign in to comment.