diff --git a/sdks/python/src/opik/evaluation/models/base_model.py b/sdks/python/src/opik/evaluation/models/base_model.py
index c197ca4c3b..168162cb52 100644
--- a/sdks/python/src/opik/evaluation/models/base_model.py
+++ b/sdks/python/src/opik/evaluation/models/base_model.py
@@ -57,7 +57,7 @@ def generate_provider_response(self, **kwargs: Any) -> Any:
             kwargs: arguments required by the provider to generate a response.
 
         Returns:
-            Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
+            Any: The response from the model provider, which can be of any type depending on the use case and LLM.
         """
         pass
 
@@ -72,6 +72,6 @@ async def agenerate_provider_response(self, **kwargs: Any) -> Any:
             kwargs: arguments required by the provider to generate a response.
 
         Returns:
-            Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
+            Any: The response from the model provider, which can be of any type depending on the use case and LLM.
         """
         pass
diff --git a/sdks/python/src/opik/evaluation/models/litellm/litellm_chat_model.py b/sdks/python/src/opik/evaluation/models/litellm/litellm_chat_model.py
index 023712e92b..6da45124d2 100644
--- a/sdks/python/src/opik/evaluation/models/litellm/litellm_chat_model.py
+++ b/sdks/python/src/opik/evaluation/models/litellm/litellm_chat_model.py
@@ -37,10 +37,14 @@ def __init__(
         You can find all possible completion_kwargs parameters here: https://docs.litellm.ai/docs/completion/input.
 
         Args:
-            model_name: The name of the LLM model to be used.
-            must_support_arguments: A list of arguments that the provider must support.
+            model_name: The name of the LLM to be used.
+                This parameter will be passed to `litellm.completion(model=model_name)` so you don't need to pass
+                the `model` argument separately inside **completion_kwargs.
+            must_support_arguments: A list of openai-like arguments that the given model + provider pair must support.
                 `litellm.get_supported_openai_params(model_name)` call is used to get supported arguments. If any is missing, ValueError is raised.
+                You can pass the arguments from the table: https://docs.litellm.ai/docs/completion/input#translated-openai-params
+
             **completion_kwargs: key-value arguments to always pass additionally into `litellm.completion` function.
         """
 
@@ -49,8 +53,8 @@ def __init__(
         self._check_model_name()
         self._check_must_support_arguments(must_support_arguments)
 
-        self._completion_kwargs: Dict[str, Any] = self._filter_supported_params(
-            completion_kwargs
+        self._completion_kwargs: Dict[str, Any] = (
+            self._remove_unnecessary_not_supported_params(completion_kwargs)
         )
 
         self._engine = litellm
@@ -62,10 +66,6 @@ def supported_params(self) -> Set[str]:
         )
         self._ensure_supported_params(supported_params)
 
-        # Add metadata and success_callback as a parameter that is always supported
-        supported_params.add("metadata")
-        supported_params.add("success_callback")
-
         return supported_params
 
     def _ensure_supported_params(self, params: Set[str]) -> None:
@@ -101,18 +101,21 @@ def _check_must_support_arguments(self, args: Optional[List[str]]) -> None:
             if key not in self.supported_params:
                 raise ValueError(f"Unsupported parameter: '{key}'!")
 
-    def _filter_supported_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
-        valid_params = {}
-
-        for key, value in params.items():
-            if key not in self.supported_params:
-                LOGGER.debug(
-                    f"This model does not support the '{key}' parameter and it has been ignored."
-                )
-            else:
-                valid_params[key] = value
+    def _remove_unnecessary_not_supported_params(
+        self, params: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        filtered_params = {**params}
+
+        if (
+            "response_format" in params
+            and "response_format" not in self.supported_params
+        ):
+            filtered_params.pop("response_format")
+            LOGGER.debug(
+                "This model does not support the response_format parameter and it will be ignored."
+            )
 
-        return valid_params
+        return filtered_params
 
     def generate_string(self, input: str, **kwargs: Any) -> str:
         """
@@ -127,7 +130,7 @@ def generate_string(self, input: str, **kwargs: Any) -> str:
             str: The generated string output.
         """
 
-        valid_litellm_params = self._filter_supported_params(kwargs)
+        valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)
 
         request = [
             {
@@ -154,13 +157,13 @@ def generate_provider_response(
             kwargs: arguments required by the provider to generate a response.
 
         Returns:
-            Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
+            Any: The response from the model provider, which can be of any type depending on the use case and LLM.
         """
 
         # we need to pop messages first, and after we will check the rest params
         messages = kwargs.pop("messages")
-        valid_litellm_params = self._filter_supported_params(kwargs)
+        valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)
         all_kwargs = {**self._completion_kwargs, **valid_litellm_params}
 
         if opik_monitor.enabled_in_config():
@@ -185,7 +188,7 @@ async def agenerate_string(self, input: str, **kwargs: Any) -> str:
             str: The generated string output.
         """
 
-        valid_litellm_params = self._filter_supported_params(kwargs)
+        valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)
 
         request = [
             {
@@ -209,13 +212,13 @@ async def agenerate_provider_response(self, **kwargs: Any) -> ModelResponse:
             kwargs: arguments required by the provider to generate a response.
 
         Returns:
-            Any: The response from the model provider, which can be of any type depending on the use case and LLM model.
+            Any: The response from the model provider, which can be of any type depending on the use case and LLM.
         """
 
         # we need to pop messages first, and after we will check the rest params
         messages = kwargs.pop("messages")
-        valid_litellm_params = self._filter_supported_params(kwargs)
+        valid_litellm_params = self._remove_unnecessary_not_supported_params(kwargs)
         all_kwargs = {**self._completion_kwargs, **valid_litellm_params}
 
         if opik_monitor.enabled_in_config():
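Usage sketch (not part of the patch): a minimal example of how the constructor contract reads after this change. The class name and import path (`LiteLLMChatModel` in `opik.evaluation.models.litellm.litellm_chat_model`) are assumed from the file path above; only the signatures and behavior visible in the diff are relied on.

    # Hypothetical usage; class name and import path are assumptions.
    from opik.evaluation.models.litellm.litellm_chat_model import LiteLLMChatModel

    # model_name is forwarded as litellm.completion(model=model_name), so no
    # separate `model` key is needed inside **completion_kwargs.
    # Each name in must_support_arguments is checked against
    # litellm.get_supported_openai_params(model_name); an unsupported name
    # raises ValueError at construction time.
    model = LiteLLMChatModel(
        model_name="gpt-4o",
        must_support_arguments=["temperature", "response_format"],
        temperature=0.0,  # stored in _completion_kwargs and sent with every call
    )

    # With this patch, kwargs are no longer filtered down to supported params;
    # only an unsupported "response_format" is dropped (with a debug log),
    # everything else is passed through to litellm.completion unchanged.
    print(model.generate_string(input="Reply with a single word: ping."))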