diff --git a/wren-ai-service/src/providers/__init__.py b/wren-ai-service/src/providers/__init__.py index 7760b6a285..13b762df1b 100644 --- a/wren-ai-service/src/providers/__init__.py +++ b/wren-ai-service/src/providers/__init__.py @@ -72,13 +72,14 @@ def llm_processor(entry: dict) -> dict: returned = {} for model in entry.get("models", []): model_name = f"{entry.get('provider')}.{model.get('model')}" + model_additional_params = { + k: v for k, v in model.items() if k not in ["model", "kwargs"] + } returned[model_name] = { - "provider": entry.get("provider"), - "model": model.get("model"), - "kwargs": model.get("kwargs"), - "api_base": model.get("api_base"), - "api_version": model.get("api_version"), - "api_key_name": model.get("api_key_name"), + "provider": entry["provider"], + "model": model["model"], + "kwargs": model["kwargs"], + **model_additional_params, **others, } return returned diff --git a/wren-ai-service/src/providers/llm/litellm.py b/wren-ai-service/src/providers/llm/litellm.py index fa002c7caa..b9093538ea 100644 --- a/wren-ai-service/src/providers/llm/litellm.py +++ b/wren-ai-service/src/providers/llm/litellm.py @@ -1,6 +1,9 @@ import os from typing import Any, Callable, Dict, List, Optional, Union +from haystack.components.generators.openai_utils import ( + _convert_message_to_openai_format, +) from haystack.dataclasses import ChatMessage, StreamingChunk from litellm import acompletion from litellm.types.utils import ModelResponse @@ -56,7 +59,7 @@ async def _run( messages = [message] openai_formatted_messages = [ - message.to_openai_format() for message in messages + _convert_message_to_openai_format(message) for message in messages ] generation_kwargs = {